Skip to content
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,12 @@ Utility
:members:
:special-members: __call__

`MapLabelValue`
"""""""""""""""
.. autoclass:: MapLabelValue
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -1052,6 +1058,12 @@ Utility (Dict)
:members:
:special-members: __call__

`MapLabelValued`
""""""""""""""""
.. autoclass:: MapLabelValued
:members:
:special-members: __call__

Transform Adaptors
------------------
.. automodule:: monai.transforms.adaptors
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@
Identity,
LabelToMask,
Lambda,
MapLabelValue,
RemoveRepeatedChannel,
RepeatChannel,
SimulateDelay,
Expand Down Expand Up @@ -325,6 +326,9 @@
Lambdad,
LambdaD,
LambdaDict,
MapLabelValued,
MapLabelValueD,
MapLabelValueDict,
RandLambdad,
RandLambdaD,
RandLambdaDict,
Expand Down
41 changes: 41 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"ConvertToMultiChannelBasedOnBratsClasses",
"AddExtremePointsChannel",
"TorchVision",
"MapLabelValue",
]


Expand Down Expand Up @@ -758,3 +759,43 @@ def __call__(self, img: torch.Tensor):

"""
return self.trans(img)


class MapLabelValue:
"""
Utility to map label values to another set of values.
For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2],
[3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc.
The label data must be numpy array or array-like data and the output data will be numpy array.

"""

def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:
"""
Args:
orig_labels: original labels that map to others.
target_labels: expected label values, 1: 1 map to the `orig_labels`.
dtype: convert the output data to dtype, default to float32.

"""
if len(orig_labels) != len(target_labels):
raise ValueError("orig_labels and target_labels must have the same length.")
self.orig_labels = orig_labels
self.target_labels = target_labels
self.dtype = dtype

def __call__(self, img: np.ndarray):
img = np.asarray(img)
img_flat = img.flatten()
try:
out_flat = np.copy(img_flat).astype(self.dtype)
except ValueError:
# can't copy unchanged labels as the expected dtype is not supported, must map all the label values
out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype)
Comment thread
Nic-Ma marked this conversation as resolved.

for o, t in zip(self.orig_labels, self.target_labels):
if o == t:
continue
np.place(out_flat, img_flat == o, t)

return out_flat.reshape(img.shape)
38 changes: 38 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Identity,
LabelToMask,
Lambda,
MapLabelValue,
RemoveRepeatedChannel,
RepeatChannel,
SimulateDelay,
Expand Down Expand Up @@ -83,6 +84,7 @@
"ConvertToMultiChannelBasedOnBratsClassesd",
"AddExtremePointsChanneld",
"TorchVisiond",
"MapLabelValued",
"IdentityD",
"IdentityDict",
"AsChannelFirstD",
Expand Down Expand Up @@ -129,6 +131,8 @@
"AddExtremePointsChannelDict",
"TorchVisionD",
"TorchVisionDict",
"MapLabelValueD",
"MapLabelValueDict",
]


Expand Down Expand Up @@ -960,6 +964,39 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
return d


class MapLabelValued(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
"""

def __init__(
self,
keys: KeysCollection,
orig_labels: Sequence,
target_labels: Sequence,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
orig_labels: original labels that map to others.
target_labels: expected label values, 1: 1 map to the `orig_labels`.
dtype: convert the output data to dtype, default to float32.
allow_missing_keys: don't raise exception if key is missing.

"""
super().__init__(keys, allow_missing_keys)
self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.mapper(d[key])
return d


IdentityD = IdentityDict = Identityd
AsChannelFirstD = AsChannelFirstDict = AsChannelFirstd
AsChannelLastD = AsChannelLastDict = AsChannelLastd
Expand Down Expand Up @@ -987,3 +1024,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
RandLambdaD = RandLambdaDict = RandLambdad
MapLabelValueD = MapLabelValueDict = MapLabelValued
71 changes: 71 additions & 0 deletions tests/test_map_label_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import MapLabelValue

TEST_CASE_1 = [
{"orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]},
np.array([[3, 1], [1, 2]]),
np.array([[0, 2], [2, 1]]),
]

TEST_CASE_2 = [
{"orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]},
np.array([[[3], [5], [5], [8]]]),
np.array([[[0], [1], [1], [2]]]),
]

TEST_CASE_3 = [
{"orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]},
np.array([3, 1, 1, 2]),
np.array([2, 0, 0, 1]),
]

TEST_CASE_4 = [
{"orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]},
np.array([3, 1, 1, 2]),
np.array([2.5, 0.5, 0.5, 1.5]),
]

TEST_CASE_5 = [
{"orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8},
np.array([3.5, 1.5, 1.5, 2.5]),
np.array([2, 0, 0, 1]),
]

TEST_CASE_6 = [
{"orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]},
np.array([["label3", "label1"], ["label1", "label2"]]),
np.array([[0, 2], [2, 1]]),
]

TEST_CASE_7 = [
{"orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"},
np.array([[3.5, 1.5], [1.5, 2.5]]),
np.array([["label0", "label2"], ["label2", "label1"]]),
]


class TestMapLabelValue(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_shape(self, input_param, input_data, expected_value):
result = MapLabelValue(**input_param)(input_data)
np.testing.assert_equal(result, expected_value)
self.assertTupleEqual(result.shape, expected_value.shape)


if __name__ == "__main__":
unittest.main()
71 changes: 71 additions & 0 deletions tests/test_map_label_valued.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import MapLabelValued

TEST_CASE_1 = [
{"keys": "seg", "orig_labels": [3, 2, 1], "target_labels": [0, 1, 2]},
{"seg": np.array([[3, 1], [1, 2]])},
np.array([[0, 2], [2, 1]]),
]

TEST_CASE_2 = [
{"keys": "seg", "orig_labels": [3, 5, 8], "target_labels": [0, 1, 2]},
{"seg": np.array([[[3], [5], [5], [8]]])},
np.array([[[0], [1], [1], [2]]]),
]

TEST_CASE_3 = [
{"keys": "seg", "orig_labels": [1, 2, 3], "target_labels": [0, 1, 2]},
{"seg": np.array([3, 1, 1, 2])},
np.array([2, 0, 0, 1]),
]

TEST_CASE_4 = [
{"keys": "seg", "orig_labels": [1, 2, 3], "target_labels": [0.5, 1.5, 2.5]},
{"seg": np.array([3, 1, 1, 2])},
np.array([2.5, 0.5, 0.5, 1.5]),
]

TEST_CASE_5 = [
{"keys": "seg", "orig_labels": [1.5, 2.5, 3.5], "target_labels": [0, 1, 2], "dtype": np.int8},
{"seg": np.array([3.5, 1.5, 1.5, 2.5])},
np.array([2, 0, 0, 1]),
]

TEST_CASE_6 = [
{"keys": "seg", "orig_labels": ["label3", "label2", "label1"], "target_labels": [0, 1, 2]},
{"seg": np.array([["label3", "label1"], ["label1", "label2"]])},
np.array([[0, 2], [2, 1]]),
]

TEST_CASE_7 = [
{"keys": "seg", "orig_labels": [3.5, 2.5, 1.5], "target_labels": ["label0", "label1", "label2"], "dtype": "str"},
{"seg": np.array([[3.5, 1.5], [1.5, 2.5]])},
np.array([["label0", "label2"], ["label2", "label1"]]),
]


class TestMapLabelValued(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_shape(self, input_param, input_data, expected_value):
result = MapLabelValued(**input_param)(input_data)
np.testing.assert_equal(result["seg"], expected_value)
self.assertTupleEqual(result["seg"].shape, expected_value.shape)


if __name__ == "__main__":
unittest.main()