Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
41e07dd
initial checkout
wyli Jun 22, 2022
7eb72e7
deepcopy condition
wyli Jun 22, 2022
8ffaffd
update types
wyli Jun 22, 2022
3197ef0
temp ignore type
wyli Jun 22, 2022
a521687
adds SpatialResample(d)
wyli Jun 22, 2022
f7123f5
adds resmapletomatch(d)
wyli Jun 22, 2022
6abd030
adds io transforms
wyli Jun 22, 2022
3c9774b
fixes tests
wyli Jun 22, 2022
b78c866
adds invertd
wyli Jun 22, 2022
758e63c
adds spacing(d)
wyli Jun 22, 2022
68331a2
adds orientation(d)
wyli Jun 22, 2022
8e390bf
adds flip(d)
wyli Jun 22, 2022
d39e0a0
adds resize(d)
wyli Jun 22, 2022
c8f9014
adds rotate(d)
wyli Jun 22, 2022
0852d02
adds rotate90(d)
wyli Jun 23, 2022
758adf6
fixes mypy
wyli Jun 23, 2022
d373d37
adds randrotate90(d)
wyli Jun 23, 2022
626b1e6
adds randrotate(d)
wyli Jun 23, 2022
e9e216e
adds randflip(d)
wyli Jun 23, 2022
22524a7
testing not tracking meta
wyli Jun 23, 2022
f41c633
adds randaxisflip(d)
wyli Jun 23, 2022
5b6f766
adds affine(d)
wyli Jun 23, 2022
130f537
adds randaffine(d)
wyli Jun 23, 2022
5ec1186
adds tests
wyli Jun 23, 2022
d30048b
adds affinegrid
wyli Jun 24, 2022
452c4df
adds randaffinegrid
wyli Jun 24, 2022
4b56e2d
adds randelastic2d(d)
wyli Jun 24, 2022
47fcd4a
adds randelastic3d(d)
wyli Jun 24, 2022
b87addf
adds griddistortion
wyli Jun 24, 2022
738031d
adds randgriddistortion(d)
wyli Jun 24, 2022
acb9672
fixes mypy
wyli Jun 24, 2022
f9d716b
Merge branch 'dev' into spatial-transform
wyli Jun 27, 2022
b19e345
adds resampler
wyli Jun 27, 2022
fbfc1c4
bc nonbreaking
wyli Jun 27, 2022
e2a4801
temp mute test_box_transform
wyli Jun 27, 2022
3150c76
fixes tests
wyli Jun 27, 2022
06c9eb7
Move metatensor support into dev branch (crop/pad) (#4548)
Nic-Ma Jun 29, 2022
efd99c4
Merge branch 'integration-metatensor' into spatial-transform
Nic-Ma Jun 29, 2022
6ac0761
Merge branch 'integration-metatensor' into spatial-transform
wyli Jun 29, 2022
e633789
adds zoom(d)
wyli Jun 29, 2022
93f2039
adds randzoom(d)
wyli Jun 29, 2022
5c1ede3
fixes mypy
wyli Jun 29, 2022
ee5658d
deepgrow dataset update
wyli Jun 29, 2022
a78511c
fixes missing __all__ item
wyli Jun 29, 2022
31c2103
update transforms/utility
wyli Jun 29, 2022
6786071
update activations
wyli Jun 29, 2022
10f8b9c
enable all testing types
wyli Jun 29, 2022
eb53388
review tests
wyli Jun 29, 2022
47a2424
Move metatensor support into dev branch (crop/pad) (#4548)
Nic-Ma Jun 29, 2022
4e6f211
Merge branch 'integration-metatensor' into spatial-transform
wyli Jun 29, 2022
653ece2
review tests
wyli Jun 29, 2022
574c840
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2022
24c5474
reivew tests
wyli Jun 29, 2022
3ca6102
review tests
wyli Jun 29, 2022
b18b3f5
fixes tests
wyli Jun 29, 2022
217beef
review/fix tests
wyli Jun 29, 2022
53d949e
fixes interp order
wyli Jun 29, 2022
de65f1b
fixes matshow3d tests
wyli Jun 29, 2022
e7fc32a
fixes tests
wyli Jun 29, 2022
2e6d922
fixes tests
wyli Jun 29, 2022
ca84709
fixes integration
wyli Jun 29, 2022
66e3e03
review tests
wyli Jun 29, 2022
381a852
simplify inverse
wyli Jun 30, 2022
de9677c
fixes tests
wyli Jun 30, 2022
4604427
slicing channels
wyli Jun 30, 2022
1d5a972
fixes typing
wyli Jun 30, 2022
1d0c129
Move metatensor support into dev branch (crop/pad) (#4548)
Nic-Ma Jun 29, 2022
e6a6204
Merge branch 'integration-metatensor' into spatial-transform
wyli Jun 30, 2022
b55cb8c
adds tests tensor
wyli Jun 30, 2022
3153d7c
more integration tests
wyli Jul 1, 2022
6bad902
fixes orig_size property
wyli Jul 1, 2022
62f57e8
fixes tests
wyli Jul 1, 2022
f80f33b
fixes premerge issues #4626 #4627
wyli Jul 2, 2022
abfbbb6
update based on comments
wyli Jul 2, 2022
4aa95c8
remove some type ignore
wyli Jul 2, 2022
2e81089
Merge remote-tracking branch 'upstream/dev' into integration-metatensor
wyli Jul 2, 2022
e075f07
Merge branch 'integration-metatensor' into spatial-transform
wyli Jul 2, 2022
35c1415
adds tests
wyli Jul 2, 2022
bcabb56
deprecated args in 0.9
wyli Jul 3, 2022
a9574c9
convert to tensor, no check isinstance
wyli Jul 3, 2022
bf5fb90
default orig_key
wyli Jul 3, 2022
cdbf046
review post
wyli Jul 3, 2022
841c4bf
review output types
wyli Jul 3, 2022
cdb1a32
fixes tests
wyli Jul 3, 2022
95cab4a
update based on comments
wyli Jul 4, 2022
94a2460
update based on comments
wyli Jul 5, 2022
b15d6bd
fixes docstring typos, type hints
wyli Jul 5, 2022
8ddbea9
fixes padding bug
wyli Jul 5, 2022
9709a4b
fixes sort_dicts
wyli Jul 5, 2022
b485f6d
Merge branch 'integration-metatensor' into spatial-transform
wyli Jul 5, 2022
74dbb25
[pre-commit.ci] pre-commit suggestions (#4631)
pre-commit-ci[bot] Jul 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion monai/apps/deepgrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

import numpy as np

from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
from monai.transforms import AsChannelFirstd, Compose, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd
from monai.utils import GridSampleMode
from monai.utils.enums import PostFix


def create_dataset(
Expand Down Expand Up @@ -128,6 +129,8 @@ def _default_transforms(image_key, label_key, pixdim):
AsChannelFirstd(keys=keys),
Orientationd(keys=keys, axcodes="RAS"),
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
FromMetaTensord(keys=keys),
ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
]
)

Expand Down
18 changes: 5 additions & 13 deletions monai/apps/detection/transforms/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,7 @@ def __init__(self, zoom: Union[Sequence[float], float], keep_size: bool = False,
self.keep_size = keep_size
self.kwargs = kwargs

def __call__(
self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int, None] = None
) -> NdarrayOrTensor: # type: ignore
def __call__(self, boxes: torch.Tensor, src_spatial_size: Union[Sequence[int], int, None] = None):
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -266,9 +264,7 @@ def __init__(self, spatial_size: Union[Sequence[int], int], size_mode: str = "al
self.size_mode = look_up_option(size_mode, ["all", "longest"])
self.spatial_size = spatial_size

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]): # type: ignore
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -316,9 +312,7 @@ class FlipBox(Transform):
def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
self.spatial_axis = spatial_axis

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]): # type: ignore
"""
Args:
boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -489,7 +483,7 @@ def __init__(

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor]
) -> Tuple[NdarrayOrTensor, Union[Tuple, NdarrayOrTensor]]:
):
"""
Args:
boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down Expand Up @@ -535,9 +529,7 @@ class RotateBox90(Rotate90):
def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None:
super().__init__(k, spatial_axes)

def __call__( # type: ignore
self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
def __call__(self, boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int]): # type: ignore
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
Expand Down
8 changes: 4 additions & 4 deletions monai/apps/detection/transforms/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def apply_affine_to_boxes(boxes: NdarrayOrTensor, affine: NdarrayOrTensor) -> Nd
return boxes_affine


def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]) -> NdarrayOrTensor:
def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]):
"""
Zoom boxes

Expand Down Expand Up @@ -128,7 +128,7 @@ def zoom_boxes(boxes: NdarrayOrTensor, zoom: Union[Sequence[float], float]) -> N

def resize_boxes(
boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int], dst_spatial_size: Union[Sequence[int], int]
) -> NdarrayOrTensor:
):
"""
Resize boxes when the corresponding image is resized

Expand Down Expand Up @@ -262,7 +262,7 @@ def convert_box_to_mask(
boxes_only_mask = resizer(boxes_only_mask[None])[0] # type: ignore
else:
# generate a rect mask
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b]) # type: ignore
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
# apply to global mask
slicing = [b]
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
Expand Down Expand Up @@ -334,7 +334,7 @@ def select_labels(
Return:
selected labels, does not share memory with original labels.
"""
labels_tuple = ensure_tuple(labels, True) # type: ignore
labels_tuple = ensure_tuple(labels, True)

labels_select_list = []
keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]
Expand Down
48 changes: 24 additions & 24 deletions monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ZoomBox,
)
from monai.apps.detection.transforms.box_ops import convert_box_to_mask
from monai.config import KeysCollection
from monai.config import KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image
from monai.data.utils import orientation_ras_lps
Expand All @@ -43,7 +43,7 @@
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep
from monai.utils import InterpolateMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix, TraceKeys
from monai.utils.type_conversion import convert_data_type

Expand Down Expand Up @@ -90,8 +90,6 @@
]

DEFAULT_POST_FIX = PostFix.meta()
InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str]
PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str]


class ConvertBoxModed(MapTransform, InvertibleTransform):
Expand Down Expand Up @@ -377,8 +375,8 @@ def __init__(
box_keys: KeysCollection,
box_ref_image_keys: KeysCollection,
zoom: Union[Sequence[float], float],
mode: InterpolateModeSequence = InterpolateMode.AREA,
padding_mode: PadModeSequence = NumpyPadMode.EDGE,
mode: SequenceStr = InterpolateMode.AREA,
padding_mode: SequenceStr = NumpyPadMode.EDGE,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
keep_size: bool = True,
allow_missing_keys: bool = False,
Expand All @@ -395,7 +393,7 @@ def __init__(
self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs)
self.keep_size = keep_size

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)

# zoom box
Expand Down Expand Up @@ -431,7 +429,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand All @@ -453,7 +451,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
align_corners=None if align_corners == TraceKeys.NONE else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[TraceKeys.EXTRA_INFO]["original_shape"], mode="edge")(d[key])
orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
d[key] = SpatialPad(orig_shape, mode="edge")(d[key])

# zoom boxes
if key_type == "box_key":
Expand Down Expand Up @@ -518,8 +517,8 @@ def __init__(
prob: float = 0.1,
min_zoom: Union[Sequence[float], float] = 0.9,
max_zoom: Union[Sequence[float], float] = 1.1,
mode: InterpolateModeSequence = InterpolateMode.AREA,
padding_mode: PadModeSequence = NumpyPadMode.EDGE,
mode: SequenceStr = InterpolateMode.AREA,
padding_mode: SequenceStr = NumpyPadMode.EDGE,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
keep_size: bool = True,
allow_missing_keys: bool = False,
Expand All @@ -544,7 +543,7 @@ def set_random_state(
self.rand_zoom.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
Expand Down Expand Up @@ -594,7 +593,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand All @@ -616,7 +615,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
align_corners=None if align_corners == TraceKeys.NONE else align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[TraceKeys.EXTRA_INFO]["original_shape"], mode="edge")(d[key])
orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
d[key] = SpatialPad(orig_shape, mode="edge")(d[key])

# zoom boxes
if key_type == "box_key":
Expand Down Expand Up @@ -661,7 +661,7 @@ def __init__(
self.flipper = Flip(spatial_axis=spatial_axis)
self.box_flipper = FlipBox(spatial_axis=self.flipper.spatial_axis)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)

for key in self.image_keys:
Expand All @@ -674,7 +674,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.push_transform(d, box_key, extra_info={"spatial_size": spatial_size, "type": "box_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -735,7 +735,7 @@ def set_random_state(
self.flipper.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
self.randomize(None)

Expand All @@ -751,7 +751,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.push_transform(d, box_key, extra_info={"spatial_size": spatial_size, "type": "box_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def randomize( # type: ignore
self.allow_smaller,
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable, torch.Tensor]]:
d = dict(data)
spatial_dims = len(d[self.image_keys[0]].shape) - 1
image_size = d[self.image_keys[0]].shape[1:]
Expand All @@ -1190,7 +1190,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab
raise ValueError("no available ROI centers to crop.")

# initialize returned list with shallow copy to preserve key ordering
results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(self.num_samples)]
results: List[Dict[Hashable, torch.Tensor]] = [dict(d) for _ in range(self.num_samples)]

# crop images and boxes for each center.
for i, center in enumerate(self.centers):
Expand Down Expand Up @@ -1255,7 +1255,7 @@ def __init__(
self.img_rotator = Rotate90(k, spatial_axes)
self.box_rotator = RotateBox90(k, spatial_axes)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:
d = dict(data)
for key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
spatial_size = list(d[box_ref_image_key].shape[1:])
Expand All @@ -1273,7 +1273,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
self.push_transform(d, key, extra_info={"type": "image_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -1327,7 +1327,7 @@ def __init__(
super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys)
self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:
self.randomize()
d = dict(data)

Expand Down Expand Up @@ -1359,7 +1359,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))
if self._rand_k % 4 == 0:
return d
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/detector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def pad_images(
if max(pt_pad_width) == 0:
# if there is no need to pad
return input_images, [orig_size] * input_images.shape[0]
mode_: str = convert_pad_mode(dst=input_images, mode=mode).value
mode_: str = convert_pad_mode(dst=input_images, mode=mode)
return F.pad(input_images, pt_pad_width, mode=mode_, **kwargs), [orig_size] * input_images.shape[0]

# If input_images: List[Tensor])
Expand All @@ -151,7 +151,7 @@ def pad_images(
# Use `SpatialPad` to match sizes, padding in the end will not affect boxes
padder = SpatialPad(spatial_size=max_spatial_size, method="end", mode=mode, **kwargs)
for idx, img in enumerate(input_images):
images[idx, ...] = padder(img) # type: ignore
images[idx, ...] = padder(img)

return images, [list(ss) for ss in image_sizes]

Expand Down
23 changes: 10 additions & 13 deletions monai/apps/nuclick/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@

import math
import random
from enum import Enum
from typing import Any, Tuple, Union

import numpy as np

from monai.config import KeysCollection
from monai.transforms import MapTransform, Randomizable, SpatialPad
from monai.utils import optional_import
from monai.utils import StrEnum, optional_import

measure, _ = optional_import("skimage.measure")
morphology, _ = optional_import("skimage.morphology")


class NuclickKeys(Enum):
class NuclickKeys(StrEnum):
"""
Keys for nuclick transforms.
"""
Expand Down Expand Up @@ -83,7 +82,7 @@ class ExtractPatchd(MapTransform):
def __init__(
self,
keys: KeysCollection,
centroid_key: str = NuclickKeys.CENTROID.value,
centroid_key: str = NuclickKeys.CENTROID,
patch_size: Union[Tuple[int, int], int] = 128,
allow_missing_keys: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -138,9 +137,9 @@ class SplitLabeld(MapTransform):
def __init__(
self,
keys: KeysCollection,
# label: str = NuclickKeys.LABEL.value,
others: str = NuclickKeys.OTHERS.value,
mask_value: str = NuclickKeys.MASK_VALUE.value,
# label: str = NuclickKeys.LABEL,
others: str = NuclickKeys.OTHERS,
mask_value: str = NuclickKeys.MASK_VALUE,
min_area: int = 5,
):

Expand Down Expand Up @@ -268,9 +267,9 @@ class AddPointGuidanceSignald(Randomizable, MapTransform):

def __init__(
self,
image: str = NuclickKeys.IMAGE.value,
label: str = NuclickKeys.LABEL.value,
others: str = NuclickKeys.OTHERS.value,
image: str = NuclickKeys.IMAGE,
label: str = NuclickKeys.LABEL,
others: str = NuclickKeys.OTHERS,
drop_rate: float = 0.5,
jitter_range: int = 3,
):
Expand Down Expand Up @@ -338,9 +337,7 @@ class AddClickSignalsd(MapTransform):
bb_size: single integer size, defines a bounding box like (bb_size, bb_size)
"""

def __init__(
self, image: str = NuclickKeys.IMAGE.value, foreground: str = NuclickKeys.FOREGROUND.value, bb_size: int = 128
):
def __init__(self, image: str = NuclickKeys.IMAGE, foreground: str = NuclickKeys.FOREGROUND, bb_size: int = 128):
self.image = image
self.foreground = foreground
self.bb_size = bb_size
Expand Down
Loading