diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index bf5ed2b180..2eb2537b49 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -446,6 +446,15 @@ Intensity :members: :special-members: __call__ + +`ForegroundMask` +"""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ForegroundMask.png + :alt: example of ForegroundMask +.. autoclass:: ForegroundMask + :members: + :special-members: __call__ + IO ^^ @@ -1339,6 +1348,13 @@ Intensity (Dict) :members: :special-members: __call__ +`ForegroundMaskd` +""""""""""""""""" +.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/ForegroundMaskd.png + :alt: example of ForegroundMaskd +.. autoclass:: ForegroundMaskd + :members: + :special-members: __call__ IO (Dict) ^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 18459c1b7b..d4f09474de 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -81,6 +81,7 @@ from .intensity.array import ( AdjustContrast, DetectEnvelope, + ForegroundMask, GaussianSharpen, GaussianSmooth, GibbsNoise, @@ -117,6 +118,9 @@ AdjustContrastd, AdjustContrastD, AdjustContrastDict, + ForegroundMaskd, + ForegroundMaskD, + ForegroundMaskDict, GaussianSharpend, GaussianSharpenD, GaussianSharpenDict, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 06b8cfa108..43ed2df62a 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -16,7 +16,7 @@ from abc import abstractmethod from collections.abc import Iterable from functools import partial -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from warnings import warn import numpy as np @@ -29,17 +29,13 @@ from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where -from monai.utils import ( - convert_data_type, - convert_to_dst_type, - ensure_tuple, - ensure_tuple_rep, - ensure_tuple_size, - fall_back_tuple, -) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import TransformBackends -from monai.utils.type_conversion import convert_to_tensor, get_equivalent_dtype +from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple +from monai.utils.module import min_version, optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype + +skimage, _ = optional_import("skimage", "0.19.0", min_version) __all__ = [ "RandGaussianNoise", @@ -2161,3 +2157,100 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: img = IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img) return img + + +class ForegroundMask(Transform): + """ + Creates a binary mask that defines the foreground based on thresholds in RGB or HSV color space. + This transform receives an RGB (or grayscale) image where by default it is assumed that the foreground has + low values (dark) while the background has high values (white). Otherwise, set `invert` argument to `True`. + + Args: + threshold: an int or a float number that defines the threshold that values less than that are foreground. + It also can be a callable that receives each dimension of the image and calculate the threshold, + or a string that defines such callable from `skimage.filter.threshold_...`. For the list of available + threshold functions, please refer to https://scikit-image.org/docs/stable/api/skimage.filters.html + Moreover, a dictionary can be passed that defines such thresholds for each channel, like + {"R": 100, "G": "otsu", "B": skimage.filter.threshold_mean} + hsv_threshold: similar to threshold but HSV color space ("H", "S", and "V"). + Unlike RBG, in HSV, value greater than `hsv_threshold` are considered foreground. + invert: invert the intensity range of the input image, so that the dtype maximum is now the dtype minimum, + and vice-versa. + + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__( + self, + threshold: Union[Dict, Callable, str, float, int] = "otsu", + hsv_threshold: Optional[Union[Dict, Callable, str, float, int]] = None, + invert: bool = False, + ) -> None: + self.thresholds: Dict[str, Union[Callable, float]] = {} + if threshold is not None: + if isinstance(threshold, dict): + for mode, th in threshold.items(): + self._set_threshold(th, mode.upper()) + else: + self._set_threshold(threshold, "R") + self._set_threshold(threshold, "G") + self._set_threshold(threshold, "B") + if hsv_threshold is not None: + if isinstance(hsv_threshold, dict): + for mode, th in hsv_threshold.items(): + self._set_threshold(th, mode.upper()) + else: + self._set_threshold(hsv_threshold, "H") + self._set_threshold(hsv_threshold, "S") + self._set_threshold(hsv_threshold, "V") + + self.thresholds = {k: v for k, v in self.thresholds.items() if v is not None} + if self.thresholds.keys().isdisjoint(set("RGBHSV")): + raise ValueError( + f"Threshold for at least one channel of RGB or HSV needs to be set. {self.thresholds} is provided." + ) + + self.invert = invert + + def _set_threshold(self, threshold, mode): + if callable(threshold): + self.thresholds[mode] = threshold + elif isinstance(threshold, str): + self.thresholds[mode] = getattr(skimage.filters, "threshold_" + threshold.lower()) + elif isinstance(threshold, (float, int)): + self.thresholds[mode] = float(threshold) + else: + raise ValueError( + f"`threshold` should be either a callable, string, or float number, {type(threshold)} was given." + ) + + def _get_threshold(self, image, mode): + threshold = self.thresholds.get(mode) + if callable(threshold): + return threshold(image) + return threshold + + def __call__(self, image: NdarrayOrTensor): + img_rgb, *_ = convert_data_type(image, np.ndarray) + if self.invert: + img_rgb = skimage.util.invert(img_rgb) + foregrounds = [] + if not self.thresholds.keys().isdisjoint(set("RGB")): + rgb_foreground = np.zeros_like(img_rgb[:1]) + for img, mode in zip(img_rgb, "RGB"): + threshold = self._get_threshold(img, mode) + if threshold: + rgb_foreground = np.logical_or(rgb_foreground, img <= threshold) + foregrounds.append(rgb_foreground) + if not self.thresholds.keys().isdisjoint(set("HSV")): + img_hsv = skimage.color.rgb2hsv(img_rgb, channel_axis=0) + hsv_foreground = np.zeros_like(img_rgb[:1]) + for img, mode in zip(img_hsv, "HSV"): + threshold = self._get_threshold(img, mode) + if threshold: + hsv_foreground = np.logical_or(hsv_foreground, img > threshold) + foregrounds.append(hsv_foreground) + + mask = np.stack(foregrounds).all(axis=0) + return convert_to_dst_type(src=mask, dst=image)[0] diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 67dc73f93e..25cf261fe1 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -23,6 +23,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.transforms.intensity.array import ( AdjustContrast, + ForegroundMask, GaussianSharpen, GaussianSmooth, GibbsNoise, @@ -88,6 +89,7 @@ "RandCoarseDropoutd", "RandCoarseShuffled", "HistogramNormalized", + "ForegroundMaskd", "RandGaussianNoiseD", "RandGaussianNoiseDict", "ShiftIntensityD", @@ -146,6 +148,8 @@ "HistogramNormalizeDict", "RandKSpaceSpikeNoiseD", "RandKSpaceSpikeNoiseDict", + "ForegroundMaskD", + "ForegroundMaskDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -1654,6 +1658,52 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +class ForegroundMaskd(MapTransform): + """ + Creates a binary mask that defines the foreground based on thresholds in RGB or HSV color space. + This transform receives an RGB (or grayscale) image where by default it is assumed that the foreground has + low values (dark) while the background is white. + + Args: + keys: keys of the corresponding items to be transformed. + threshold: an int or a float number that defines the threshold that values less than that are foreground. + It also can be a callable that receives each dimension of the image and calculate the threshold, + or a string that defines such callable from `skimage.filter.threshold_...`. For the list of available + threshold functions, please refer to https://scikit-image.org/docs/stable/api/skimage.filters.html + Moreover, a dictionary can be passed that defines such thresholds for each channel, like + {"R": 100, "G": "otsu", "B": skimage.filter.threshold_mean} + hsv_threshold: similar to threshold but HSV color space ("H", "S", and "V"). + Unlike RBG, in HSV, value greater than `hsv_threshold` are considered foreground. + invert: invert the intensity range of the input image, so that the dtype maximum is now the dtype minimum, + and vice-versa. + new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of + key intact. By default not prefix is set and the corresponding array to the key will be replaced. + allow_missing_keys: do not raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + threshold: Union[Dict, Callable, str, float] = "otsu", + hsv_threshold: Optional[Union[Dict, Callable, str, float, int]] = None, + invert: bool = False, + new_key_prefix: Optional[str] = None, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.transform = ForegroundMask(threshold=threshold, hsv_threshold=hsv_threshold, invert=invert) + self.new_key_prefix = new_key_prefix + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + new_key = key if self.new_key_prefix is None else self.new_key_prefix + key + d[new_key] = self.transform(d[key]) + + return d + + RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd @@ -1683,3 +1733,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled +ForegroundMaskD = ForegroundMaskDict = ForegroundMaskd diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py index b096e1b93d..6165496599 100644 --- a/monai/transforms/utils_create_transform_ims.py +++ b/monai/transforms/utils_create_transform_ims.py @@ -85,6 +85,7 @@ ) from monai.transforms.intensity.array import ( AdjustContrast, + ForegroundMask, GaussianSharpen, GaussianSmooth, GibbsNoise, @@ -115,6 +116,7 @@ ) from monai.transforms.intensity.dictionary import ( AdjustContrastd, + ForegroundMaskd, GaussianSharpend, GaussianSmoothd, GibbsNoised, @@ -584,6 +586,8 @@ def create_transform_im( create_transform_im( MaskIntensityd, dict(keys=CommonKeys.IMAGE, mask_key=CommonKeys.IMAGE, select_fn=lambda x: x > 0.3), data ) + create_transform_im(ForegroundMask, dict(invert=True), data) + create_transform_im(ForegroundMaskd, dict(keys=CommonKeys.IMAGE, invert=True), data) create_transform_im(GaussianSmooth, dict(sigma=2), data) create_transform_im(GaussianSmoothd, dict(keys=CommonKeys.IMAGE, sigma=2), data) create_transform_im(RandGaussianSmooth, dict(prob=1.0, sigma_x=(1, 2)), data) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index cd8555d173..0d5d8bf92d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -33,6 +33,7 @@ MetricReduction, NumpyPadMode, PostFix, + ProbMapKeys, PytorchPadMode, SkipMode, TraceKeys, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 50b55560f9..9fb4e480f6 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -332,6 +332,18 @@ class BoxModeName(Enum): CCCWHD = "cccwhd" # [xcenter, ycenter, zcenter, xsize, ysize, zsize] +class ProbMapKeys(Enum): + """ + The keys to be used for generating the probability maps from patches + """ + + LOCATION = "mask_location" + SIZE = "mask_size" + COUNT = "num_patches" + PATH = "path" + PRE_PATH = "image" + + class GridPatchSort(Enum): """ The sorting method for the generated patches in `GridPatch` diff --git a/requirements-dev.txt b/requirements-dev.txt index ac8b3730d8..7bc06b8039 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ itk>=5.2 nibabel pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 tensorboard -scikit-image>=0.14.2 +scikit-image>=0.19.0 tqdm>=4.47.0 lmdb flake8>=3.8.1 diff --git a/tests/min_tests.py b/tests/min_tests.py index b52dc2a73d..cc35cf687f 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -53,6 +53,8 @@ def run_testsuit(): "test_ensure_channel_firstd", "test_fill_holes", "test_fill_holesd", + "test_foreground_mask", + "test_foreground_maskd", "test_global_mutual_information_loss", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", diff --git a/tests/test_foreground_mask.py b/tests/test_foreground_mask.py new file mode 100644 index 0000000000..c18e87fe53 --- /dev/null +++ b/tests/test_foreground_mask.py @@ -0,0 +1,96 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.intensity.array import ForegroundMask +from monai.utils import min_version, optional_import, set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +skimage, has_skimage = optional_import("skimage", "0.19.0", min_version) +set_determinism(1234) + +A = np.random.randint(64, 128, (3, 3, 2)).astype(np.uint8) +A3D = np.random.randint(64, 128, (3, 3, 2, 2)).astype(np.uint8) +B = np.ones_like(A[:1]) +B3D = np.ones_like(A3D[:1]) +MASK = np.pad(B, ((0, 0), (2, 2), (2, 2)), constant_values=0) +MASK3D = np.pad(B3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=0) +IMAGE1 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=255) +IMAGE3D = np.pad(A3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=255) +IMAGE2 = np.copy(IMAGE1) +IMAGE2[0] = 0 +IMAGE3 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=0) +TEST_CASE_0 = [{}, IMAGE1, MASK] +TEST_CASE_1 = [{"threshold": "otsu"}, IMAGE1, MASK] +TEST_CASE_2 = [{"threshold": "otsu"}, IMAGE2, MASK] +TEST_CASE_3 = [{"threshold": 140}, IMAGE1, MASK] +TEST_CASE_4 = [{"threshold": "otsu", "invert": True}, IMAGE3, MASK] +TEST_CASE_5 = [{"threshold": 0.5}, MASK, np.logical_not(MASK)] +TEST_CASE_6 = [{"threshold": 140}, IMAGE2, np.ones_like(MASK)] +TEST_CASE_7 = [{"threshold": {"R": "otsu", "G": "otsu", "B": "otsu"}}, IMAGE2, MASK] +TEST_CASE_8 = [{"threshold": {"R": 140, "G": "otsu", "B": "otsu"}}, IMAGE2, np.ones_like(MASK)] +TEST_CASE_9 = [{"threshold": {"R": 140, "G": skimage.filters.threshold_otsu, "B": "otsu"}}, IMAGE2, np.ones_like(MASK)] +TEST_CASE_10 = [{"threshold": skimage.filters.threshold_mean}, IMAGE1, MASK] +TEST_CASE_11 = [{"threshold": None, "hsv_threshold": "otsu"}, IMAGE1, np.ones_like(MASK)] +TEST_CASE_12 = [{"threshold": None, "hsv_threshold": {"S": "otsu"}}, IMAGE1, MASK] +TEST_CASE_13 = [{"threshold": 100, "invert": True}, IMAGE1, np.logical_not(MASK)] +TEST_CASE_14 = [{}, IMAGE3D, MASK3D] +TEST_CASE_15 = [{"hsv_threshold": {"S": 0.1}}, IMAGE3D, MASK3D] + +TEST_CASE_ERROR_1 = [{"threshold": None}, IMAGE1] +TEST_CASE_ERROR_2 = [{"threshold": {"K": 1}}, IMAGE1] + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, *TEST_CASE_0]) + TESTS.append([p, *TEST_CASE_1]) + TESTS.append([p, *TEST_CASE_2]) + TESTS.append([p, *TEST_CASE_3]) + TESTS.append([p, *TEST_CASE_4]) + TESTS.append([p, *TEST_CASE_5]) + TESTS.append([p, *TEST_CASE_6]) + TESTS.append([p, *TEST_CASE_7]) + TESTS.append([p, *TEST_CASE_8]) + TESTS.append([p, *TEST_CASE_9]) + TESTS.append([p, *TEST_CASE_10]) + TESTS.append([p, *TEST_CASE_11]) + TESTS.append([p, *TEST_CASE_12]) + TESTS.append([p, *TEST_CASE_13]) + TESTS.append([p, *TEST_CASE_14]) + TESTS.append([p, *TEST_CASE_15]) + +TESTS_ERROR = [] +for p in TEST_NDARRAYS: + TESTS_ERROR.append([p, *TEST_CASE_ERROR_1]) + TESTS_ERROR.append([p, *TEST_CASE_ERROR_2]) + + +@unittest.skipUnless(has_skimage, "Requires sci-kit image") +class TestForegroundMask(unittest.TestCase): + @parameterized.expand(TESTS) + def test_foreground_mask(self, in_type, arguments, image, mask): + input_image = in_type(image) + result = ForegroundMask(**arguments)(input_image) + assert_allclose(result, mask, type_test=False) + + @parameterized.expand(TESTS_ERROR) + def test_foreground_mask_error(self, in_type, arguments, image): + input_image = in_type(image) + with self.assertRaises(ValueError): + ForegroundMask(**arguments)(input_image) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_foreground_maskd.py b/tests/test_foreground_maskd.py new file mode 100644 index 0000000000..3c8aa08d7f --- /dev/null +++ b/tests/test_foreground_maskd.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.intensity.dictionary import ForegroundMaskd +from monai.utils import min_version, optional_import, set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +skimage, has_skimage = optional_import("skimage", "0.19.0", min_version) +set_determinism(1234) + +A = np.random.randint(64, 128, (3, 3, 2)).astype(np.uint8) +A3D = np.random.randint(64, 128, (3, 3, 2, 2)).astype(np.uint8) +B = np.ones_like(A[:1]) +B3D = np.ones_like(A3D[:1]) +MASK = np.pad(B, ((0, 0), (2, 2), (2, 2)), constant_values=0) +MASK3D = np.pad(B3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=0) +IMAGE1 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=255) +IMAGE3D = np.pad(A3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=255) +IMAGE2 = np.copy(IMAGE1) +IMAGE2[0] = 0 +IMAGE3 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=0) +TEST_CASE_0 = [{"keys": "image"}, {"image": IMAGE1}, MASK] +TEST_CASE_1 = [{"keys": "image", "threshold": "otsu"}, {"image": IMAGE1}, MASK] +TEST_CASE_2 = [{"keys": "image", "threshold": "otsu"}, {"image": IMAGE2}, MASK] +TEST_CASE_3 = [{"keys": "image", "threshold": 140}, {"image": IMAGE1}, MASK] +TEST_CASE_4 = [{"keys": "image", "threshold": "otsu", "invert": True}, {"image": IMAGE3}, MASK] +TEST_CASE_5 = [{"keys": "image", "threshold": 0.5}, {"image": MASK}, np.logical_not(MASK)] +TEST_CASE_6 = [{"keys": "image", "threshold": 140}, {"image": IMAGE2}, np.ones_like(MASK)] +TEST_CASE_7 = [{"keys": "image", "threshold": {"R": "otsu", "G": "otsu", "B": "otsu"}}, {"image": IMAGE2}, MASK] +TEST_CASE_8 = [ + {"keys": "image", "threshold": {"R": 140, "G": "otsu", "B": "otsu"}}, + {"image": IMAGE2}, + np.ones_like(MASK), +] +TEST_CASE_9 = [ + {"keys": "image", "threshold": {"R": 140, "G": skimage.filters.threshold_otsu, "B": "otsu"}}, + {"image": IMAGE2}, + np.ones_like(MASK), +] +TEST_CASE_10 = [{"keys": "image", "threshold": skimage.filters.threshold_mean}, {"image": IMAGE1}, MASK] +TEST_CASE_11 = [{"keys": "image", "threshold": None, "hsv_threshold": "otsu"}, {"image": IMAGE1}, np.ones_like(MASK)] +TEST_CASE_12 = [{"keys": "image", "threshold": None, "hsv_threshold": {"S": "otsu"}}, {"image": IMAGE1}, MASK] +TEST_CASE_13 = [{"keys": "image", "threshold": 100, "invert": True}, {"image": IMAGE1}, np.logical_not(MASK)] +TEST_CASE_14 = [{"keys": "image"}, {"image": IMAGE3D}, MASK3D] +TEST_CASE_15 = [{"keys": "image", "hsv_threshold": {"S": 0.1}}, {"image": IMAGE3D}, MASK3D] + +TEST_CASE_ERROR_1 = [{"keys": "image", "threshold": None}, {"image": IMAGE1}] +TEST_CASE_ERROR_2 = [{"keys": "image", "threshold": {"K": 1}}, {"image": IMAGE1}] + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, *TEST_CASE_0]) + TESTS.append([p, *TEST_CASE_1]) + TESTS.append([p, *TEST_CASE_2]) + TESTS.append([p, *TEST_CASE_3]) + TESTS.append([p, *TEST_CASE_4]) + TESTS.append([p, *TEST_CASE_5]) + TESTS.append([p, *TEST_CASE_6]) + TESTS.append([p, *TEST_CASE_7]) + TESTS.append([p, *TEST_CASE_8]) + TESTS.append([p, *TEST_CASE_9]) + TESTS.append([p, *TEST_CASE_10]) + TESTS.append([p, *TEST_CASE_11]) + TESTS.append([p, *TEST_CASE_12]) + TESTS.append([p, *TEST_CASE_13]) + TESTS.append([p, *TEST_CASE_14]) + TESTS.append([p, *TEST_CASE_15]) + +TESTS_ERROR = [] +for p in TEST_NDARRAYS: + TESTS_ERROR.append([p, *TEST_CASE_ERROR_1]) + TESTS_ERROR.append([p, *TEST_CASE_ERROR_2]) + + +@unittest.skipUnless(has_skimage, "Requires sci-kit image") +class TestForegroundMaskd(unittest.TestCase): + @parameterized.expand(TESTS) + def test_foreground_mask(self, in_type, arguments, data_dict, mask): + data_dict[arguments["keys"]] = in_type(data_dict[arguments["keys"]]) + result = ForegroundMaskd(**arguments)(data_dict)[arguments["keys"]] + assert_allclose(result, mask, type_test=False) + + @parameterized.expand(TESTS_ERROR) + def test_foreground_mask_error(self, in_type, arguments, data_dict): + data_dict[arguments["keys"]] = in_type(data_dict[arguments["keys"]]) + with self.assertRaises(ValueError): + ForegroundMaskd(**arguments)(data_dict)[arguments["keys"]] + + +if __name__ == "__main__": + unittest.main()