From 8cab9b6b7f24638b033c354e489f8dd4d61d9038 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 14 Aug 2021 16:34:47 -0500 Subject: [PATCH 1/2] defines DictOfNamedArrays.__eq__ --- pytato/array.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytato/array.py b/pytato/array.py index 25ffc525e..8fe6ea49e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -815,6 +815,13 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: return iter(self._data) + def __eq__(self, other: Any) -> bool: + if self is other: + return True + + return (isinstance(other, DictOfNamedArrays) + and self._data == other._data) + # }}} From 5b9929acf39646f1afd176da1a34189e948c069d Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 14 Aug 2021 16:35:13 -0500 Subject: [PATCH 2/2] test equality comparison of dict-of-named-arrays --- test/test_pytato.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index b77c08913..5c3fdb519 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -266,6 +266,18 @@ class BestArrayTag(Tag): assert any(isinstance(tag, BestArrayTag) for tag in y.tags) +def test_dict_of_named_arrays_comparison(): + # See https://github.com/inducer/pytato/pull/137 + x = pt.make_placeholder("x", (10, 4), float) + dict1 = pt.make_dict_of_named_arrays({"out": 2 * x}) + dict2 = pt.make_dict_of_named_arrays({"out": 2 * x}) + dict3 = pt.make_dict_of_named_arrays({"not_out": 2 * x}) + dict4 = pt.make_dict_of_named_arrays({"out": 3 * x}) + assert dict1 == dict2 + assert dict1 != dict3 + assert dict1 != dict4 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])