diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 768c0665a2..28bfdc5f24 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -551,6 +551,12 @@ Utility :members: :special-members: __call__ +`MapLabelValue` +""""""""""""""" +.. autoclass:: MapLabelValue + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -1052,6 +1058,12 @@ Utility (Dict) :members: :special-members: __call__ +`MapLabelValued` +"""""""""""""""" +.. autoclass:: MapLabelValued + :members: + :special-members: __call__ + Transform Adaptors ------------------ .. automodule:: monai.transforms.adaptors diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 22311cdca6..c7b60e15e3 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -268,6 +268,7 @@ Identity, LabelToMask, Lambda, + MapLabelValue, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -325,6 +326,9 @@ Lambdad, LambdaD, LambdaDict, + MapLabelValued, + MapLabelValueD, + MapLabelValueDict, RandLambdad, RandLambdaD, RandLambdaDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f169002596..987542c979 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -57,6 +57,7 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "MapLabelValue", ] @@ -758,3 +759,43 @@ def __call__(self, img: torch.Tensor): """ return self.trans(img) + + +class MapLabelValue: + """ + Utility to map label values to another set of values. + For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2], + [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc. + The label data must be numpy array or array-like data and the output data will be numpy array. + + """ + + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: + """ + Args: + orig_labels: original labels that map to others. + target_labels: expected label values, 1: 1 map to the `orig_labels`. + dtype: convert the output data to dtype, default to float32. + + """ + if len(orig_labels) != len(target_labels): + raise ValueError("orig_labels and target_labels must have the same length.") + self.orig_labels = orig_labels + self.target_labels = target_labels + self.dtype = dtype + + def __call__(self, img: np.ndarray): + img = np.asarray(img) + img_flat = img.flatten() + try: + out_flat = np.copy(img_flat).astype(self.dtype) + except ValueError: + # can't copy unchanged labels as the expected dtype is not supported, must map all the label values + out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) + + for o, t in zip(self.orig_labels, self.target_labels): + if o == t: + continue + np.place(out_flat, img_flat == o, t) + + return out_flat.reshape(img.shape) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 324835a874..e7cf63e210 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -36,6 +36,7 @@ Identity, LabelToMask, Lambda, + MapLabelValue, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -83,6 +84,7 @@ "ConvertToMultiChannelBasedOnBratsClassesd", "AddExtremePointsChanneld", "TorchVisiond", + "MapLabelValued", "IdentityD", "IdentityDict", "AsChannelFirstD", @@ -129,6 +131,8 @@ "AddExtremePointsChannelDict", "TorchVisionD", "TorchVisionDict", + "MapLabelValueD", + "MapLabelValueDict", ] @@ -960,6 +964,39 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class MapLabelValued(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. + """ + + def __init__( + self, + keys: KeysCollection, + orig_labels: Sequence, + target_labels: Sequence, + dtype: DtypeLike = np.float32, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + orig_labels: original labels that map to others. + target_labels: expected label values, 1: 1 map to the `orig_labels`. + dtype: convert the output data to dtype, default to float32. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype) + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.mapper(d[key]) + return d + + IdentityD = IdentityDict = Identityd AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd AsChannelLastD = AsChannelLastDict = AsChannelLastd @@ -987,3 +1024,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond RandLambdaD = RandLambdaDict = RandLambdad +MapLabelValueD = MapLabelValueDict = MapLabelValued diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py new file mode 100644 index 0000000000..98412ab800 --- /dev/null +++ b/tests/test_map_label_value.py @@ -0,0 +1,71 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import MapLabelValue + +TEST_CASE_1 = [ + {"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, + np.array([[3, 1], [1, 2]]), + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_2 = [ + {"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + np.array([[[3], [5], [5], [8]]]), + np.array([[[0], [1], [1], [2]]]), +] + +TEST_CASE_3 = [ + {"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, + np.array([3, 1, 1, 2]), + np.array([2, 0, 0, 1]), +] + +TEST_CASE_4 = [ + {"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, + np.array([3, 1, 1, 2]), + np.array([2.5, 0.5, 0.5, 1.5]), +] + +TEST_CASE_5 = [ + {"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + np.array([3.5, 1.5, 1.5, 2.5]), + np.array([2, 0, 0, 1]), +] + +TEST_CASE_6 = [ + {"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + np.array([["label3", "label1"], ["label1", "label2"]]), + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_7 = [ + {"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, + np.array([[3.5, 1.5], [1.5, 2.5]]), + np.array([["label0", "label2"], ["label2", "label1"]]), +] + + +class TestMapLabelValue(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + def test_shape(self, input_param, input_data, expected_value): + result = MapLabelValue(**input_param)(input_data) + np.testing.assert_equal(result, expected_value) + self.assertTupleEqual(result.shape, expected_value.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_map_label_valued.py b/tests/test_map_label_valued.py new file mode 100644 index 0000000000..426ac28836 --- /dev/null +++ b/tests/test_map_label_valued.py @@ -0,0 +1,71 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import MapLabelValued + +TEST_CASE_1 = [ + {"keys": "seg", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]}, + {"seg": np.array([[3, 1], [1, 2]])}, + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_2 = [ + {"keys": "seg", "orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]}, + {"seg": np.array([[[3], [5], [5], [8]]])}, + np.array([[[0], [1], [1], [2]]]), +] + +TEST_CASE_3 = [ + {"keys": "seg", "orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]}, + {"seg": np.array([3, 1, 1, 2])}, + np.array([2, 0, 0, 1]), +] + +TEST_CASE_4 = [ + {"keys": "seg", "orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]}, + {"seg": np.array([3, 1, 1, 2])}, + np.array([2.5, 0.5, 0.5, 1.5]), +] + +TEST_CASE_5 = [ + {"keys": "seg", "orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8}, + {"seg": np.array([3.5, 1.5, 1.5, 2.5])}, + np.array([2, 0, 0, 1]), +] + +TEST_CASE_6 = [ + {"keys": "seg", "orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]}, + {"seg": np.array([["label3", "label1"], ["label1", "label2"]])}, + np.array([[0, 2], [2, 1]]), +] + +TEST_CASE_7 = [ + {"keys": "seg", "orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"}, + {"seg": np.array([[3.5, 1.5], [1.5, 2.5]])}, + np.array([["label0", "label2"], ["label2", "label1"]]), +] + + +class TestMapLabelValued(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + def test_shape(self, input_param, input_data, expected_value): + result = MapLabelValued(**input_param)(input_data) + np.testing.assert_equal(result["seg"], expected_value) + self.assertTupleEqual(result["seg"].shape, expected_value.shape) + + +if __name__ == "__main__": + unittest.main()