From 421f79017de8c7439dbb2985872cda3b2cd6c11e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Jul 2021 17:52:31 +0800 Subject: [PATCH 1/3] [DLMED] fix tests for PyTorch 1.5 Signed-off-by: Nic Ma --- tests/test_ensure_type.py | 9 ++++++++- tests/test_ensure_typed.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 5941b5ce23..4c11230d7e 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -27,13 +27,20 @@ def test_array_input(self): self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): + for test_data in (5, 5.0, np.asarray(5), torch.tensor(5)): for dtype in ("tensor", "numpy"): result = EnsureType(data_type=dtype)(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) + def test_bool_input(self): + for dtype in ("tensor", "numpy"): + result = EnsureType(data_type=dtype)(data=False) + self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertFalse(result) + self.assertEqual(result.ndim, 0) + def test_string(self): for dtype in ("tensor", "numpy"): # string input diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 58fa78d102..df7eaaf103 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -27,13 +27,20 @@ def test_array_input(self): self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): + for test_data in (5, 5.0, np.asarray(5), torch.tensor(5)): for dtype in ("tensor", "numpy"): result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) torch.testing.assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) + def test_bool_input(self): + for dtype in ("tensor", "numpy"): + result = EnsureTyped(keys="data", data_type=dtype)(data={"data": False})["data"] + self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertFalse(result) + self.assertEqual(result.ndim, 0) + def test_string(self): for dtype in ("tensor", "numpy"): # string input From 13b39bc91ee4c31439efd8afe4a311e2f2b3aa76 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Jul 2021 18:03:00 +0800 Subject: [PATCH 2/3] [DLMED] simplify the tests Signed-off-by: Nic Ma --- tests/test_ensure_type.py | 14 +++++--------- tests/test_ensure_typed.py | 14 +++++--------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 4c11230d7e..11cf6760fb 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -27,20 +27,16 @@ def test_array_input(self): self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, np.asarray(5), torch.tensor(5)): + for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): for dtype in ("tensor", "numpy"): result = EnsureType(data_type=dtype)(test_data) self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - torch.testing.assert_allclose(result, test_data) + if isinstance(test_data, bool): + self.assertFalse(result) + else: + torch.testing.assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) - def test_bool_input(self): - for dtype in ("tensor", "numpy"): - result = EnsureType(data_type=dtype)(data=False) - self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - self.assertFalse(result) - self.assertEqual(result.ndim, 0) - def test_string(self): for dtype in ("tensor", "numpy"): # string input diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index df7eaaf103..c5f588d423 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -27,20 +27,16 @@ def test_array_input(self): self.assertTupleEqual(result.shape, (2, 2)) def test_single_input(self): - for test_data in (5, 5.0, np.asarray(5), torch.tensor(5)): + for test_data in (5, 5.0, False, np.asarray(5), torch.tensor(5)): for dtype in ("tensor", "numpy"): result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - torch.testing.assert_allclose(result, test_data) + if isinstance(test_data, bool): + self.assertFalse(result) + else: + torch.testing.assert_allclose(result, test_data) self.assertEqual(result.ndim, 0) - def test_bool_input(self): - for dtype in ("tensor", "numpy"): - result = EnsureTyped(keys="data", data_type=dtype)(data={"data": False})["data"] - self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) - self.assertFalse(result) - self.assertEqual(result.ndim, 0) - def test_string(self): for dtype in ("tensor", "numpy"): # string input From b305cb37fd2c9c3c67b42e4d20e3934ddcf9ec7b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 8 Jul 2021 12:04:22 +0100 Subject: [PATCH 3/3] remove dev quick tests which is mostly duplicated Signed-off-by: Wenqi Li --- .github/workflows/pythonapp.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index b2ddb74d34..f1ab392206 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -4,7 +4,6 @@ on: # quick tests for pull requests and the releasing branches push: branches: - - dev - main - releasing/* pull_request: