From 2b84bc88941d13d954fdedb58deec57ede4f5cb9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 2 Jun 2022 14:36:22 +0800 Subject: [PATCH 1/4] [DLMED] enhance ensure_tuple Signed-off-by: Nic Ma --- monai/utils/misc.py | 6 ++--- tests/test_ensure_tuple.py | 49 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/test_ensure_tuple.py diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 521968a87d..86bbd484b5 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc import inspect import itertools import os @@ -19,6 +18,7 @@ import types import warnings from ast import literal_eval +from collections.abc import Iterable from distutils.util import strtobool from pathlib import Path from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast @@ -90,14 +90,14 @@ def issequenceiterable(obj: Any) -> bool: """ if isinstance(obj, torch.Tensor): return int(obj.dim()) > 0 # a 0-d tensor is not iterable - return isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)) + return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) def ensure_tuple(vals: Any) -> Tuple[Any, ...]: """ Returns a tuple of `vals`. """ - if not issequenceiterable(vals): + if not issequenceiterable(vals) or isinstance(vals, (torch.Tensor, np.ndarray)): return (vals,) return tuple(vals) diff --git a/tests/test_ensure_tuple.py b/tests/test_ensure_tuple.py new file mode 100644 index 0000000000..e9ba0d4f88 --- /dev/null +++ b/tests/test_ensure_tuple.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.utils.misc import ensure_tuple +from tests.utils import assert_allclose + +TESTS = [ + ["test", ("test",)], + [["test1", "test2"], ("test1", "test2")], + [123, (123,)], + [(1, 2, 3), (1, 2, 3)], + [np.array([1, 2]), (np.array([1, 2]),)], + [torch.tensor([1, 2]), (torch.tensor([1, 2]),)], + [np.array([]), (np.array([]),)], + [torch.tensor([]), (torch.tensor([]),)], + [np.array(123), (np.array(123),)], + [torch.tensor(123), (torch.tensor(123),)], +] + + +class TestEnsureTuple(unittest.TestCase): + @parameterized.expand(TESTS) + def test_value(self, input, expected_value): + result = ensure_tuple(input) + self.assertTrue(isinstance(result, tuple)) + if isinstance(input, (np.ndarray, torch.Tensor)): + for i, j in zip(result, expected_value): + assert_allclose(i, j) + else: + self.assertTupleEqual(result, expected_value) + + +if __name__ == "__main__": + + unittest.main() From e373e31c89f4d9380f21831a722e0674ce68782f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 2 Jun 2022 15:48:03 +0800 Subject: [PATCH 2/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/utils/misc.py | 10 ++++++++-- tests/test_ensure_tuple.py | 12 +++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 86bbd484b5..feb9c151bb 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -93,11 +93,17 @@ def issequenceiterable(obj: Any) -> bool: return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) -def ensure_tuple(vals: Any) -> Tuple[Any, ...]: +def ensure_tuple(vals: Any, wrap_array: bool = False) -> Tuple[Any, ...]: """ Returns a tuple of `vals`. + + Args: + vals: input data to convert to a tuple. + wrap_array: if `True`, treat the whole input array as one item of the tuple. + if `False`, automatically convert the input array with `tuple(vals)`, default to `False`. + """ - if not issequenceiterable(vals) or isinstance(vals, (torch.Tensor, np.ndarray)): + if not issequenceiterable(vals) or wrap_array: return (vals,) return tuple(vals) diff --git a/tests/test_ensure_tuple.py b/tests/test_ensure_tuple.py index e9ba0d4f88..cc427da5ab 100644 --- a/tests/test_ensure_tuple.py +++ b/tests/test_ensure_tuple.py @@ -23,19 +23,21 @@ [["test1", "test2"], ("test1", "test2")], [123, (123,)], [(1, 2, 3), (1, 2, 3)], - [np.array([1, 2]), (np.array([1, 2]),)], - [torch.tensor([1, 2]), (torch.tensor([1, 2]),)], + [np.array([1, 2]), (np.array([1, 2]),), True], + [np.array([1, 2]), (1, 2), False], + [torch.tensor([1, 2]), (torch.tensor([1, 2]),), True], [np.array([]), (np.array([]),)], [torch.tensor([]), (torch.tensor([]),)], - [np.array(123), (np.array(123),)], + [np.array(123), (np.array(123),), True], [torch.tensor(123), (torch.tensor(123),)], ] class TestEnsureTuple(unittest.TestCase): @parameterized.expand(TESTS) - def test_value(self, input, expected_value): - result = ensure_tuple(input) + def test_value(self, input, expected_value, wrap_array=False): + result = ensure_tuple(input, wrap_array) + self.assertTrue(isinstance(result, tuple)) if isinstance(input, (np.ndarray, torch.Tensor)): for i, j in zip(result, expected_value): From 2c7ecbb6cccd714ac828b20256596830ab85c83e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 2 Jun 2022 10:02:31 +0100 Subject: [PATCH 3/4] update Signed-off-by: Wenqi Li --- monai/utils/misc.py | 18 +++++++++++------- tests/test_ensure_tuple.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index feb9c151bb..5e659efe7c 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -88,8 +88,13 @@ def issequenceiterable(obj: Any) -> bool: """ Determine if the object is an iterable sequence and is not a string. """ - if isinstance(obj, torch.Tensor): - return int(obj.dim()) > 0 # a 0-d tensor is not iterable + try: + if hasattr(obj, "ndim") and obj.ndim == 0: + return False + if isinstance(obj, torch.Tensor): + return int(obj.dim()) > 0 # a 0-d tensor is not iterable + except Exception: + return False return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) @@ -99,14 +104,13 @@ def ensure_tuple(vals: Any, wrap_array: bool = False) -> Tuple[Any, ...]: Args: vals: input data to convert to a tuple. - wrap_array: if `True`, treat the whole input array as one item of the tuple. - if `False`, automatically convert the input array with `tuple(vals)`, default to `False`. + wrap_array: if `True`, treat the input numerical array (ndarray/tensor) as one item of the tuple. + if `False`, try to convert the array with `tuple(vals)`, default to `False`. """ - if not issequenceiterable(vals) or wrap_array: + if wrap_array and isinstance(vals, (np.ndarray, torch.Tensor)): return (vals,) - - return tuple(vals) + return tuple(vals) if issequenceiterable(vals) else (vals,) def ensure_tuple_size(tup: Any, dim: int, pad_val: Any = 0) -> Tuple[Any, ...]: diff --git a/tests/test_ensure_tuple.py b/tests/test_ensure_tuple.py index cc427da5ab..ea580871da 100644 --- a/tests/test_ensure_tuple.py +++ b/tests/test_ensure_tuple.py @@ -22,7 +22,8 @@ ["test", ("test",)], [["test1", "test2"], ("test1", "test2")], [123, (123,)], - [(1, 2, 3), (1, 2, 3)], + [(1, [2], 3), (1, [2], 3)], + [(1, 2, 3), (1, 2, 3), True], [np.array([1, 2]), (np.array([1, 2]),), True], [np.array([1, 2]), (1, 2), False], [torch.tensor([1, 2]), (torch.tensor([1, 2]),), True], From 07e2ee485699ef615f5f5d407279cb192094e70d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 2 Jun 2022 10:41:11 +0100 Subject: [PATCH 4/4] update cond Signed-off-by: Wenqi Li --- monai/utils/misc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 5e659efe7c..99f645a704 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -90,9 +90,7 @@ def issequenceiterable(obj: Any) -> bool: """ try: if hasattr(obj, "ndim") and obj.ndim == 0: - return False - if isinstance(obj, torch.Tensor): - return int(obj.dim()) > 0 # a 0-d tensor is not iterable + return False # a 0-d tensor is not iterable except Exception: return False return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))