Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,28 +58,33 @@ class SpatialPad(Transform):
for additional details.

Args:
c: the spatial size of output data after padding, if a dimension of the input
spatial_size: the spatial size of output data after padding, if a dimension of the input
data size is bigger than the pad size, will not pad that dimension.
If its components have non-positive values, the corresponding size of input image will be used (no padding).
for example: if the spatial size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`,
the spatial size of output data will be [32, 30, 30].
If its components have non-positive values, the corresponding size of input image will be used
(no padding). for example: if the spatial size of input data is [30, 30, 30] and
`spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30].
method: {``"symmetric"``, ``"end"``}
Pad image symmetric on every side or only pad at the end sides. Defaults to ``"symmetric"``.
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

"""

def __init__(
self,
spatial_size: Union[Sequence[int], int],
method: Union[Method, str] = Method.SYMMETRIC,
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
**np_kwargs,
) -> None:
self.spatial_size = spatial_size
self.method: Method = Method(method)
self.mode: NumpyPadMode = NumpyPadMode(mode)
self.np_kwargs = np_kwargs

def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]:
spatial_size = fall_back_tuple(self.spatial_size, data_shape)
Expand All @@ -106,7 +111,9 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
if not np.asarray(all_pad_width).any():
# all zeros, skip padding
return img
img = np.pad(img, all_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value)

mode = self.mode.value if mode is None else NumpyPadMode(mode).value
img = np.pad(img, all_pad_width, mode=mode, **self.np_kwargs)
return img


Expand All @@ -130,13 +137,20 @@ class BorderPad(Transform):
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

"""

def __init__(
self, spatial_border: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT
self,
spatial_border: Union[Sequence[int], int],
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
**np_kwargs,
) -> None:
self.spatial_border = spatial_border
self.mode: NumpyPadMode = NumpyPadMode(mode)
self.np_kwargs = np_kwargs

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None):
"""
Expand Down Expand Up @@ -172,17 +186,21 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]."
)

return np.pad(
img, [(0, 0)] + data_pad_width, mode=self.mode.value if mode is None else NumpyPadMode(mode).value
)
mode = self.mode.value if mode is None else NumpyPadMode(mode).value
return np.pad(img, [(0, 0)] + data_pad_width, mode=mode, **self.np_kwargs)


class DivisiblePad(Transform):
"""
Pad the input data, so that the spatial sizes are divisible by `k`.
"""

def __init__(self, k: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT) -> None:
def __init__(
self,
k: Union[Sequence[int], int],
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
**np_kwargs,
) -> None:
"""
Args:
k: the target k for each spatial dimension.
Expand All @@ -192,11 +210,14 @@ def __init__(self, k: Union[Sequence[int], int], mode: Union[NumpyPadMode, str]
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

See also :py:class:`monai.transforms.SpatialPad`
"""
self.k = k
self.mode: NumpyPadMode = NumpyPadMode(mode)
self.np_kwargs = np_kwargs

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray:
"""
Expand All @@ -209,7 +230,12 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
"""
new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k)
spatial_pad = SpatialPad(spatial_size=new_size, method=Method.SYMMETRIC, mode=mode or self.mode)
spatial_pad = SpatialPad(
spatial_size=new_size,
method=Method.SYMMETRIC,
mode=mode or self.mode,
**self.np_kwargs,
)

return spatial_pad(img)

Expand Down
15 changes: 12 additions & 3 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
method: Union[Method, str] = Method.SYMMETRIC,
mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
allow_missing_keys: bool = False,
**np_kwargs,
) -> None:
"""
Args:
Expand All @@ -129,11 +130,13 @@ def __init__(
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
It also can be a sequence of string, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

"""
super().__init__(keys, allow_missing_keys)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padder = SpatialPad(spatial_size, method)
self.padder = SpatialPad(spatial_size, method, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
Expand Down Expand Up @@ -175,6 +178,7 @@ def __init__(
spatial_border: Union[Sequence[int], int],
mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
allow_missing_keys: bool = False,
**np_kwargs,
) -> None:
"""
Args:
Expand All @@ -197,11 +201,13 @@ def __init__(
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
It also can be a sequence of string, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

"""
super().__init__(keys, allow_missing_keys)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padder = BorderPad(spatial_border=spatial_border)
self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
Expand Down Expand Up @@ -247,6 +253,7 @@ def __init__(
k: Union[Sequence[int], int],
mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
allow_missing_keys: bool = False,
**np_kwargs,
) -> None:
"""
Args:
Expand All @@ -261,13 +268,15 @@ def __init__(
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
It also can be a sequence of string, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.
more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html

See also :py:class:`monai.transforms.SpatialPad`

"""
super().__init__(keys, allow_missing_keys)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padder = DivisiblePad(k=k)
self.padder = DivisiblePad(k=k, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_border_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def test_pad_shape(self, input_param, input_data, expected_val):
result = padder(input_data, mode=input_param["mode"])
self.assertAlmostEqual(result.shape, expected_val.shape)

def test_pad_kwargs(self):
padder = BorderPad(spatial_border=2, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))
result = padder(np.zeros((3, 8, 4)))
np.testing.assert_allclose(result[:, :2, 2:6], np.ones((3, 2, 4)))
np.testing.assert_allclose(result[:, :, :2], np.ones((3, 12, 2)) + 1)


if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions tests/test_divisible_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def test_pad_shape(self, input_param, input_data, expected_val):
result = padder(input_data, mode=input_param["mode"])
self.assertAlmostEqual(result.shape, expected_val.shape)

def test_pad_kwargs(self):
padder = DivisiblePad(k=5, mode="constant", constant_values=((0, 0), (1, 1), (2, 2)))
result = padder(np.zeros((3, 8, 4)))
np.testing.assert_allclose(result[:, :1, :4], np.ones((3, 1, 4)))
np.testing.assert_allclose(result[:, :, 4:5], np.ones((3, 10, 1)) + 1)


if __name__ == "__main__":
unittest.main()
8 changes: 8 additions & 0 deletions tests/test_spatial_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ def test_pad_shape(self, input_param, input_data, expected_val):
result = padder(input_data, mode=input_param["mode"])
np.testing.assert_allclose(result.shape, expected_val.shape)

def test_pad_kwargs(self):
padder = SpatialPad(
spatial_size=[15, 8], method="end", mode="constant", constant_values=((0, 0), (1, 1), (2, 2))
)
result = padder(np.zeros((3, 8, 4)))
np.testing.assert_allclose(result[:, 8:, :4], np.ones((3, 7, 4)))
np.testing.assert_allclose(result[:, :, 4:], np.ones((3, 15, 4)) + 1)


if __name__ == "__main__":
unittest.main()