diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 20d9458882..420f8cd9dd 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -478,8 +478,8 @@ def __call__(self, data): if np.all(np.less(current_size, self.spatial_size)): cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) - box_start = cropper.roi_start - box_end = cropper.roi_end + box_start = np.array([s.start for s in cropper.slices]) + box_end = np.array([s.stop for s in cropper.slices]) else: cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index 2d57ed9325..83bb5ebaa4 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -176,6 +176,19 @@ np.array([[[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]]), ] +CROP_TEST_CASE_2 = [ + { + "keys": ["image", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + "spatial_size": [2, 4, 4], + }, + DATA_1, + np.array([1, 1, 4, 4]), +] + ADD_INITIAL_POINT_TEST_CASE_1 = [ {"label": "label", "guidance": "guidance", "sids": "sids"}, DATA_1, @@ -360,6 +373,11 @@ def test_correct_results(self, arguments, input_data, expected_result): result = SpatialCropForegroundd(**arguments)(input_data) np.testing.assert_allclose(result["image"], expected_result) + @parameterized.expand([CROP_TEST_CASE_2]) + def test_correct_shape(self, arguments, input_data, expected_shape): + result = SpatialCropForegroundd(**arguments)(input_data) + np.testing.assert_equal(result["image"].shape, expected_shape) + @parameterized.expand([CROP_TEST_CASE_1]) def test_foreground_position(self, arguments, input_data, _): result = SpatialCropForegroundd(**arguments)(input_data)