diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index d4bab741eb..a37e9b9791 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -58,17 +58,20 @@ class SpatialPad(Transform): for additional details. Args: - c: the spatial size of output data after padding, if a dimension of the input + spatial_size: the spatial size of output data after padding, if a dimension of the input data size is bigger than the pad size, will not pad that dimension. - If its components have non-positive values, the corresponding size of input image will be used (no padding). - for example: if the spatial size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`, - the spatial size of output data will be [32, 30, 30]. + If its components have non-positive values, the corresponding size of input image will be used + (no padding). for example: if the spatial size of input data is [30, 30, 30] and + `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30]. method: {``"symmetric"``, ``"end"``} Pad image symmetric on every side or only pad at the end sides. Defaults to ``"symmetric"``. mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( @@ -76,10 +79,12 @@ def __init__( spatial_size: Union[Sequence[int], int], method: Union[Method, str] = Method.SYMMETRIC, mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, ) -> None: self.spatial_size = spatial_size self.method: Method = Method(method) self.mode: NumpyPadMode = NumpyPadMode(mode) + self.np_kwargs = np_kwargs def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]: spatial_size = fall_back_tuple(self.spatial_size, data_shape) @@ -106,7 +111,9 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N if not np.asarray(all_pad_width).any(): # all zeros, skip padding return img - img = np.pad(img, all_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value) + + mode = self.mode.value if mode is None else NumpyPadMode(mode).value + img = np.pad(img, all_pad_width, mode=mode, **self.np_kwargs) return img @@ -130,13 +137,20 @@ class BorderPad(Transform): ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( - self, spatial_border: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT + self, + spatial_border: Union[Sequence[int], int], + mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, ) -> None: self.spatial_border = spatial_border self.mode: NumpyPadMode = NumpyPadMode(mode) + self.np_kwargs = np_kwargs def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None): """ @@ -172,9 +186,8 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]." ) - return np.pad( - img, [(0, 0)] + data_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value - ) + mode = self.mode.value if mode is None else NumpyPadMode(mode).value + return np.pad(img, [(0, 0)] + data_pad_width, mode=mode, **self.np_kwargs) class DivisiblePad(Transform): @@ -182,7 +195,12 @@ class DivisiblePad(Transform): Pad the input data, so that the spatial sizes are divisible by `k`. """ - def __init__(self, k: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT) -> None: + def __init__( + self, + k: Union[Sequence[int], int], + mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, + ) -> None: """ Args: k: the target k for each spatial dimension. @@ -192,11 +210,14 @@ def __init__(self, k: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html See also :py:class:`monai.transforms.SpatialPad` """ self.k = k self.mode: NumpyPadMode = NumpyPadMode(mode) + self.np_kwargs = np_kwargs def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: """ @@ -209,7 +230,12 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) - spatial_pad = SpatialPad(spatial_size=new_size, method=Method.SYMMETRIC, mode=mode or self.mode) + spatial_pad = SpatialPad( + spatial_size=new_size, + method=Method.SYMMETRIC, + mode=mode or self.mode, + **self.np_kwargs, + ) return spatial_pad(img) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index ed01559ff5..d7f40233e5 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -111,6 +111,7 @@ def __init__( method: Union[Method, str] = Method.SYMMETRIC, mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, + **np_kwargs, ) -> None: """ Args: @@ -129,11 +130,13 @@ def __init__( See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = SpatialPad(spatial_size, method) + self.padder = SpatialPad(spatial_size, method, **np_kwargs) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -175,6 +178,7 @@ def __init__( spatial_border: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, + **np_kwargs, ) -> None: """ Args: @@ -197,11 +201,13 @@ def __init__( See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = BorderPad(spatial_border=spatial_border) + self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) @@ -247,6 +253,7 @@ def __init__( k: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, allow_missing_keys: bool = False, + **np_kwargs, ) -> None: """ Args: @@ -261,13 +268,15 @@ def __init__( See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html See also :py:class:`monai.transforms.SpatialPad` """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k) + self.padder = DivisiblePad(k=k, **np_kwargs) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) diff --git a/tests/test_border_pad.py b/tests/test_border_pad.py index 14d93aae4e..b011601694 100644 --- a/tests/test_border_pad.py +++ b/tests/test_border_pad.py @@ -51,6 +51,12 @@ def test_pad_shape(self, input_param, input_data, expected_val): result = padder(input_data, mode=input_param["mode"]) self.assertAlmostEqual(result.shape, expected_val.shape) + def test_pad_kwargs(self): + padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) + result = padder(np.zeros((3, 8, 4))) + np.testing.assert_allclose(result[:, :2, 2:6], np.ones((3, 2, 4))) + np.testing.assert_allclose(result[:, :, :2], np.ones((3, 12, 2)) + 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index 27965b51d9..1ca7e9f46a 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -40,6 +40,12 @@ def test_pad_shape(self, input_param, input_data, expected_val): result = padder(input_data, mode=input_param["mode"]) self.assertAlmostEqual(result.shape, expected_val.shape) + def test_pad_kwargs(self): + padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2))) + result = padder(np.zeros((3, 8, 4))) + np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4))) + np.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_spatial_pad.py b/tests/test_spatial_pad.py index 4473a23770..93241610de 100644 --- a/tests/test_spatial_pad.py +++ b/tests/test_spatial_pad.py @@ -44,6 +44,14 @@ def test_pad_shape(self, input_param, input_data, expected_val): result = padder(input_data, mode=input_param["mode"]) np.testing.assert_allclose(result.shape, expected_val.shape) + def test_pad_kwargs(self): + padder = SpatialPad( + spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2)) + ) + result = padder(np.zeros((3, 8, 4))) + np.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4))) + np.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1) + if __name__ == "__main__": unittest.main()