From dc9229daabcc9130627e0a54ddd51a3f45cf261c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 26 Mar 2021 20:07:20 +0800 Subject: [PATCH 1/5] [DLMED] add MapLabelValue Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f169002596..47afc23301 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -758,3 +758,32 @@ def __call__(self, img: torch.Tensor): """ return self.trans(img) + + +class MapLabelValue: + """ + This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. + As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input + data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. + + """ + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: torch.Tensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + + """ + return self.trans(img) From 62243eaa21c9e855c9978d6ed4066d38ff4e63dc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 26 Mar 2021 21:28:31 +0800 Subject: [PATCH 2/5] [DLMED] add unit tests Signed-off-by: Nic Ma --- docs/source/transforms.rst | 12 +++++++ monai/transforms/__init__.py | 4 +++ monai/transforms/utility/array.py | 35 ++++++++++--------- monai/transforms/utility/dictionary.py | 35 +++++++++++++++++++ tests/test_map_label_value.py | 47 ++++++++++++++++++++++++++ tests/test_map_label_valued.py | 47 ++++++++++++++++++++++++++ 6 files changed, 164 insertions(+), 16 deletions(-) create mode 100644 tests/test_map_label_value.py create mode 100644 tests/test_map_label_valued.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 768c0665a2..28bfdc5f24 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -551,6 +551,12 @@ Utility :members: :special-members: __call__ +`MapLabelValue` +""""""""""""""" +.. autoclass:: MapLabelValue + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -1052,6 +1058,12 @@ Utility (Dict) :members: :special-members: __call__ +`MapLabelValued` +"""""""""""""""" +.. autoclass:: MapLabelValued + :members: + :special-members: __call__ + Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 22311cdca6..c7b60e15e3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -268,6 +268,7 @@ Identity, LabelToMask, Lambda, + MapLabelValue, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -325,6 +326,9 @@ Lambdad, LambdaD, LambdaDict, + MapLabelValued, + MapLabelValueD, + MapLabelValueDict, RandLambdad, RandLambdaD, RandLambdaDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 47afc23301..34f5d7a8d8 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -57,6 +57,7 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "MapLabelValue", ] @@ -762,28 +763,30 @@ def __call__(self, img: torch.Tensor): class MapLabelValue: """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. + Utility to map label values to another set of values. + For example, map [3, 2, 1] to [0, 1, 2], [3, 5, 8] -> [1, 2, 3], etc. """ - def __init__(self, name: str, *args, **kwargs) -> None: + def __init__(self, orig_labels: Sequence[str], target_labels: Sequence[str]) -> None: """ Args: - name: The transform name in TorchVision package. - args: parameters for the TorchVision transform. - kwargs: parameters for the TorchVision transform. + orig_labels: original labels that map to others. + target_labels: expected label values, 1: 1 map to the `orig_labels`. """ - super().__init__() - transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) - self.trans = transform(*args, **kwargs) + if len(orig_labels) != len(target_labels): + raise ValueError("orig_labels and target_labels must have the same length.") + self.orig_labels = orig_labels + self.target_labels = target_labels - def __call__(self, img: torch.Tensor): - """ - Args: - img: PyTorch Tensor data for the TorchVision transform. + def __call__(self, img: np.ndarray): + img_flat = img.flatten() + out_flat = np.copy(img_flat) - """ - return self.trans(img) + for o, t in zip(self.orig_labels, self.target_labels): + if o == t: + continue + np.place(out_flat, img_flat == o, t) + + return out_flat.reshape(img.shape) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 324835a874..7cfd34eb06 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -36,6 +36,7 @@ Identity, LabelToMask, Lambda, + MapLabelValue, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -83,6 +84,7 @@ "ConvertToMultiChannelBasedOnBratsClassesd", "AddExtremePointsChanneld", "TorchVisiond", + "MapLabelValued", "IdentityD", "IdentityDict", "AsChannelFirstD", @@ -129,6 +131,8 @@ "AddExtremePointsChannelDict", "TorchVisionD", "TorchVisionDict", + "MapLabelValueD", + "MapLabelValueDict", ] @@ -960,6 +964,36 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class MapLabelValued(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. + """ + + def __init__( + self, + keys: KeysCollection, + orig_labels: Sequence[str], + target_labels: Sequence[str], + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + orig_labels: original labels that map to others. + target_labels: expected label values, 1: 1 map to the `orig_labels`. + + """ + super().__init__(keys, allow_missing_keys) + self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.mapper(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -987,3 +1021,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandLambdaD = RandLambdaDict = RandLambdad +MapLabelValueD = MapLabelValueDict = MapLabelValued diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py new file mode 100644 index 0000000000..f8041b2b93 --- /dev/null +++ b/tests/test_map_label_value.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import MapLabelValue + +TEST_CASE_1 = [ + {"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, + np.array([[3, 1], [1, 2]]), + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_2 = [ + {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + np.array([[[3], [5], [5], [8]]]), + np.array([[[0], [1], [1], [2]]]), +] + +TEST_CASE_3 = [ + {"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, + np.array([3, 1, 1, 2]), + np.array([2, 0, 0, 1]), +] + + +class TestMapLabelValue(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_param, input_data, expected_value): + result = MapLabelValue(**input_param)(input_data) + np.testing.assert_allclose(result, expected_value) + self.assertTupleEqual(result.shape, expected_value.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py new file mode 100644 index 0000000000..973eb3a455 --- /dev/null +++ b/tests/test_map_label_valued.py @@ -0,0 +1,47 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import MapLabelValued + +TEST_CASE_1 = [ + {"keys": "label", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, + {"label": np.array([[3, 1], [1, 2]])}, + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_2 = [ + {"keys": "label", "orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + {"label": np.array([[[3], [5], [5], [8]]])}, + np.array([[[0], [1], [1], [2]]]), +] + +TEST_CASE_3 = [ + {"keys": "label", "orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, + {"label": np.array([3, 1, 1, 2])}, + np.array([2, 0, 0, 1]), +] + + +class TestMapLabelValued(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_param, input_data, expected_value): + result = MapLabelValued(**input_param)(input_data) + np.testing.assert_allclose(result["label"], expected_value) + self.assertTupleEqual(result["label"].shape, expected_value.shape) + + +if __name__ == "__main__": + unittest.main() From 6e69ebb7f80179ae8847a7d11697626bd1db111c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 26 Mar 2021 21:31:21 +0800 Subject: [PATCH 3/5] [DLMED] add missing doc-string Signed-off-by: Nic Ma --- monai/transforms/utility/dictionary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7cfd34eb06..287f21cca2 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -982,6 +982,7 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` orig_labels: original labels that map to others. target_labels: expected label values, 1: 1 map to the `orig_labels`. + allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) From 162fb2fe121089f7dbd8b1c61b9e8bac5500d00d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 27 Mar 2021 00:20:52 +0800 Subject: [PATCH 4/5] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/utility/array.py | 15 +++++++-- monai/transforms/utility/dictionary.py | 8 +++-- tests/test_map_label_value.py | 28 +++++++++++++++-- tests/test_map_label_valued.py | 42 ++++++++++++++++++++------ 4 files changed, 76 insertions(+), 17 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 34f5d7a8d8..987542c979 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -764,25 +764,34 @@ def __call__(self, img: torch.Tensor): class MapLabelValue: """ Utility to map label values to another set of values. - For example, map [3, 2, 1] to [0, 1, 2], [3, 5, 8] -> [1, 2, 3], etc. + For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2], + [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc. + The label data must be numpy array or array-like data and the output data will be numpy array. """ - def __init__(self, orig_labels: Sequence[str], target_labels: Sequence[str]) -> None: + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: """ Args: orig_labels: original labels that map to others. target_labels: expected label values, 1: 1 map to the `orig_labels`. + dtype: convert the output data to dtype, default to float32. """ if len(orig_labels) != len(target_labels): raise ValueError("orig_labels and target_labels must have the same length.") self.orig_labels = orig_labels self.target_labels = target_labels + self.dtype = dtype def __call__(self, img: np.ndarray): + img = np.asarray(img) img_flat = img.flatten() - out_flat = np.copy(img_flat) + try: + out_flat = np.copy(img_flat).astype(self.dtype) + except ValueError: + # can't copy unchanged labels as the expected dtype is not supported, must map all the label values + out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) for o, t in zip(self.orig_labels, self.target_labels): if o == t: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 287f21cca2..e7cf63e210 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -972,8 +972,9 @@ class MapLabelValued(MapTransform): def __init__( self, keys: KeysCollection, - orig_labels: Sequence[str], - target_labels: Sequence[str], + orig_labels: Sequence, + target_labels: Sequence, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -982,11 +983,12 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` orig_labels: original labels that map to others. target_labels: expected label values, 1: 1 map to the `orig_labels`. + dtype: convert the output data to dtype, default to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels) + self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index f8041b2b93..d8c18537f8 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -34,12 +34,36 @@ np.array([2, 0, 0, 1]), ] +TEST_CASE_4 = [ + {"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, + np.array([3, 1, 1, 2]), + np.array([2.5, 0.5, 0.5, 1.5]), +] + +TEST_CASE_5 = [ + {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + np.array([3.5, 1.5, 1.5, 2.5]), + np.array([2, 0, 0, 1]), +] + +TEST_CASE_6 = [ + {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_7 = [ + {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": np.str}, + np.array([[3.5, 1.5], [1.5, 2.5]]), + np.array([["label0", "label2"], ["label2", "label1"]]), +] + class TestMapLabelValue(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValue(**input_param)(input_data) - np.testing.assert_allclose(result, expected_value) + np.testing.assert_equal(result, expected_value) self.assertTupleEqual(result.shape, expected_value.shape) diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index 973eb3a455..9590b149d3 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -17,30 +17,54 @@ from monai.transforms import MapLabelValued TEST_CASE_1 = [ - {"keys": "label", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, - {"label": np.array([[3, 1], [1, 2]])}, + {"keys": "seg", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, + {"seg": np.array([[3, 1], [1, 2]])}, np.array([[0, 2], [2, 1]]), ] TEST_CASE_2 = [ - {"keys": "label", "orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, - {"label": np.array([[[3], [5], [5], [8]]])}, + {"keys": "seg", "orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + {"seg": np.array([[[3], [5], [5], [8]]])}, np.array([[[0], [1], [1], [2]]]), ] TEST_CASE_3 = [ - {"keys": "label", "orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, - {"label": np.array([3, 1, 1, 2])}, + {"keys": "seg", "orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, + {"seg": np.array([3, 1, 1, 2])}, np.array([2, 0, 0, 1]), ] +TEST_CASE_4 = [ + {"keys": "seg", "orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, + {"seg": np.array([3, 1, 1, 2])}, + np.array([2.5, 0.5, 0.5, 1.5]), +] + +TEST_CASE_5 = [ + {"keys": "seg", "orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + {"seg": np.array([3.5, 1.5, 1.5, 2.5])}, + np.array([2, 0, 0, 1]), +] + +TEST_CASE_6 = [ + {"keys": "seg", "orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + {"seg": np.array([["label3", "label1"], ["label1", "label2"]])}, + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_7 = [ + {"keys": "seg", "orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": np.str}, + {"seg": np.array([[3.5, 1.5], [1.5, 2.5]])}, + np.array([["label0", "label2"], ["label2", "label1"]]), +] + class TestMapLabelValued(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_shape(self, input_param, input_data, expected_value): result = MapLabelValued(**input_param)(input_data) - np.testing.assert_allclose(result["label"], expected_value) - self.assertTupleEqual(result["label"].shape, expected_value.shape) + np.testing.assert_equal(result["seg"], expected_value) + self.assertTupleEqual(result["seg"].shape, expected_value.shape) if __name__ == "__main__": From 1b8cfa29f52923a2d896a901acb65d9d5273128a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 27 Mar 2021 07:58:00 +0800 Subject: [PATCH 5/5] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- tests/test_map_label_value.py | 2 +- tests/test_map_label_valued.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py index d8c18537f8..98412ab800 100644 --- a/tests/test_map_label_value.py +++ b/tests/test_map_label_value.py @@ -53,7 +53,7 @@ ] TEST_CASE_7 = [ - {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": np.str}, + {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, np.array([[3.5, 1.5], [1.5, 2.5]]), np.array([["label0", "label2"], ["label2", "label1"]]), ] diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py index 9590b149d3..426ac28836 100644 --- a/tests/test_map_label_valued.py +++ b/tests/test_map_label_valued.py @@ -53,7 +53,7 @@ ] TEST_CASE_7 = [ - {"keys": "seg", "orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": np.str}, + {"keys": "seg", "orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, {"seg": np.array([[3.5, 1.5], [1.5, 2.5]])}, np.array([["label0", "label2"], ["label2", "label1"]]), ]