From 99df9528e7b9489600ebf40cda232b9641e987ad Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 14 Jul 2021 17:09:55 +0800 Subject: [PATCH] [DLMED] enhance all the SpatialPad related Signed-off-by: Nic Ma --- monai/data/utils.py | 6 +++++- monai/transforms/croppad/array.py | 6 +++++- monai/transforms/croppad/batch.py | 7 ++++++- monai/transforms/croppad/dictionary.py | 5 ++++- tests/test_divisible_pad.py | 2 +- tests/test_divisible_padd.py | 2 +- tests/test_pad_collation.py | 5 ++++- 7 files changed, 26 insertions(+), 7 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index ed6a956108..4809bd9a49 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -420,6 +420,7 @@ def pad_list_data_collate( batch: Sequence, method: Union[Method, str] = Method.SYMMETRIC, mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, ): """ Function version of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`. @@ -437,10 +438,13 @@ def pad_list_data_collate( batch: batch of data to pad-collate method: padding method (see :py:class:`monai.transforms.SpatialPad`) mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ from monai.transforms.croppad.batch import PadListDataCollate # needs to be here to avoid circular import - return PadListDataCollate(method, mode)(batch) + return PadListDataCollate(method=method, mode=mode, **np_kwargs)(batch) def no_collation(x): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 7377edef55..4ee5737bee 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -202,6 +202,7 @@ def __init__( self, k: Union[Sequence[int], int], mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + method: Union[Method, str] = Method.SYMMETRIC, **np_kwargs, ) -> None: """ @@ -213,6 +214,8 @@ def __init__( ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + method: {``"symmetric"``, ``"end"``} + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -220,6 +223,7 @@ def __init__( """ self.k = k self.mode: NumpyPadMode = NumpyPadMode(mode) + self.method: Method = Method(method) self.np_kwargs = np_kwargs def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray: @@ -235,7 +239,7 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N new_size = compute_divisible_spatial_size(spatial_shape=img.shape[1:], k=self.k) spatial_pad = SpatialPad( spatial_size=new_size, - method=Method.SYMMETRIC, + method=self.method, mode=mode or self.mode, **self.np_kwargs, ) diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 3ecabc387b..2c93c7b954 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -60,15 +60,20 @@ class PadListDataCollate(InvertibleTransform): Args: method: padding method (see :py:class:`monai.transforms.SpatialPad`) mode: padding mode (see :py:class:`monai.transforms.SpatialPad`) + np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. + more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html + """ def __init__( self, method: Union[Method, str] = Method.SYMMETRIC, mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT, + **np_kwargs, ) -> None: self.method = method self.mode = mode + self.np_kwargs = np_kwargs def __call__(self, batch: Any): """ @@ -99,7 +104,7 @@ def __call__(self, batch: Any): # Default params are central padding, padding with 0's # If input is dictionary, use the dictionary version so that the transformation is recorded - padder = SpatialPad(max_shape, self.method, self.mode) # type: ignore + padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.np_kwargs) transform = padder if not output_to_tensor else Compose([padder, ToTensor()]) for idx, batch_i in enumerate(batch): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 22adaa372a..346071aa3b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -254,6 +254,7 @@ def __init__( keys: KeysCollection, k: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT, + method: Union[Method, str] = Method.SYMMETRIC, allow_missing_keys: bool = False, **np_kwargs, ) -> None: @@ -269,6 +270,8 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html It also can be a sequence of string, each element corresponds to a key in ``keys``. + method: {``"symmetric"``, ``"end"``} + Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``. allow_missing_keys: don't raise exception if key is missing. np_kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension. more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -278,7 +281,7 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.mode = ensure_tuple_rep(mode, len(self.keys)) - self.padder = DivisiblePad(k=k, **np_kwargs) + self.padder = DivisiblePad(k=k, method=method, **np_kwargs) def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) diff --git a/tests/test_divisible_pad.py b/tests/test_divisible_pad.py index 1ca7e9f46a..e4415a2f22 100644 --- a/tests/test_divisible_pad.py +++ b/tests/test_divisible_pad.py @@ -25,7 +25,7 @@ # pad all dimensions to be divisible by 5 TEST_CASE_2 = [ - {"k": 5, "mode": "constant"}, + {"k": 5, "mode": "constant", "method": "end"}, np.zeros((3, 10, 5, 17)), np.zeros((3, 10, 5, 20)), ] diff --git a/tests/test_divisible_padd.py b/tests/test_divisible_padd.py index d894a9f42e..c834adac6d 100644 --- a/tests/test_divisible_padd.py +++ b/tests/test_divisible_padd.py @@ -23,7 +23,7 @@ ] TEST_CASE_2 = [ - {"keys": ["img"], "k": 7, "mode": "constant"}, + {"keys": ["img"], "k": 7, "mode": "constant", "method": "end"}, {"img": np.zeros((3, 8, 7))}, np.zeros((3, 14, 7)), ] diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py index 3835dc8895..a8c544558f 100644 --- a/tests/test_pad_collation.py +++ b/tests/test_pad_collation.py @@ -34,7 +34,10 @@ TESTS: List[Tuple] = [] -for pad_collate in [pad_list_data_collate, PadListDataCollate()]: +for pad_collate in [ + lambda x: pad_list_data_collate(batch=x, method="end", mode="constant", constant_values=1), + PadListDataCollate(method="end", mode="constant", constant_values=1), +]: TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True))) TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False))) TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))