Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/transforms/croppad/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]:

d = deepcopy(data)
for key in d.keys():
transform_key = str(key) + InverseKeys.KEY_SUFFIX.value
transform_key = str(key) + InverseKeys.KEY_SUFFIX
if transform_key in d.keys():
transform = d[transform_key][-1]
if transform[InverseKeys.CLASS_NAME.value] == PadListDataCollate.__name__:
if transform[InverseKeys.CLASS_NAME] == PadListDataCollate.__name__:
d[key] = CenterSpatialCrop(transform["orig_size"])(d[key])
# remove transform
d[transform_key].pop()
Expand Down
22 changes: 11 additions & 11 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = transform[InverseKeys.ORIG_SIZE.value]
orig_size = transform[InverseKeys.ORIG_SIZE]
if self.padder.method == Method.SYMMETRIC:
current_size = d[key].shape[1:]
roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)]
Expand Down Expand Up @@ -202,15 +202,15 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value])
orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
roi_start = np.array(self.padder.spatial_border)
# Need to convert single value to [min1,min2,...]
if roi_start.size == 1:
roi_start = np.full((len(orig_size)), roi_start)
# need to convert [min1,max1,min2,...] to [min1,min2,...]
elif roi_start.size == 2 * orig_size.size:
roi_start = roi_start[::2]
roi_end = np.array(transform[InverseKeys.ORIG_SIZE.value]) + roi_start
roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start

inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end)
# Apply inverse transform
Expand Down Expand Up @@ -268,7 +268,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value])
orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
roi_start = np.floor((current_size - orig_size) / 2)
roi_end = orig_size + roi_start
Expand Down Expand Up @@ -323,7 +323,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = transform[InverseKeys.ORIG_SIZE.value]
orig_size = transform[InverseKeys.ORIG_SIZE]
pad_to_start = np.array(self.cropper.roi_start)
pad_to_end = orig_size - self.cropper.roi_end
# interleave mins and maxes
Expand Down Expand Up @@ -369,7 +369,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value])
orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
# in each direction, if original size is even and current size is odd, += 1
Expand Down Expand Up @@ -449,12 +449,12 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = transform[InverseKeys.ORIG_SIZE.value]
orig_size = transform[InverseKeys.ORIG_SIZE]
random_center = self.random_center
pad_to_start = np.empty((len(orig_size)), dtype=np.int32)
pad_to_end = np.empty((len(orig_size)), dtype=np.int32)
if random_center:
for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO.value]["slices"]):
for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]):
pad_to_start[i] = _slice[0]
pad_to_end[i] = orig_size[i] - _slice[1]
else:
Expand Down Expand Up @@ -594,8 +594,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value])
extra_info = transform[InverseKeys.EXTRA_INFO.value]
orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
extra_info = transform[InverseKeys.EXTRA_INFO]
pad_to_start = np.array(extra_info["box_start"])
pad_to_end = orig_size - np.array(extra_info["box_end"])
# interleave mins and maxes
Expand Down Expand Up @@ -827,7 +827,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value])
orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
# Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding.
# Instead, we first pad any smaller dimensions, and then we crop any larger dimensions.
Expand Down
20 changes: 10 additions & 10 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,43 +72,43 @@ def push_transform(
orig_size: Optional[Tuple] = None,
) -> None:
"""Append to list of applied transforms for that key."""
key_transform = str(key) + InverseKeys.KEY_SUFFIX.value
key_transform = str(key) + InverseKeys.KEY_SUFFIX
info = {
InverseKeys.CLASS_NAME.value: self.__class__.__name__,
InverseKeys.ID.value: id(self),
InverseKeys.ORIG_SIZE.value: orig_size or data[key].shape[1:],
InverseKeys.CLASS_NAME: self.__class__.__name__,
InverseKeys.ID: id(self),
InverseKeys.ORIG_SIZE: orig_size or data[key].shape[1:],
}
if extra_info is not None:
info[InverseKeys.EXTRA_INFO.value] = extra_info
info[InverseKeys.EXTRA_INFO] = extra_info
# If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
if isinstance(self, RandomizableTransform):
info[InverseKeys.DO_TRANSFORM.value] = self._do_transform
info[InverseKeys.DO_TRANSFORM] = self._do_transform
# If this is the first, create list
if key_transform not in data:
data[key_transform] = []
data[key_transform].append(info)

def check_transforms_match(self, transform: dict) -> None:
"""Check transforms are of same instance."""
if transform[InverseKeys.ID.value] == id(self):
if transform[InverseKeys.ID] == id(self):
return
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
if (
torch.multiprocessing.get_start_method(allow_none=False) == "spawn"
and transform[InverseKeys.CLASS_NAME.value] == self.__class__.__name__
and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__
):
return
raise RuntimeError("Should inverse most recently applied invertible transform first")

def get_most_recent_transform(self, data: dict, key: Hashable) -> dict:
"""Get most recent transform."""
transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX.value][-1])
transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX][-1])
self.check_transforms_match(transform)
return transform

def pop_transform(self, data: dict, key: Hashable) -> None:
"""Remove most recent transform."""
data[str(key) + InverseKeys.KEY_SUFFIX.value].pop()
data[str(key) + InverseKeys.KEY_SUFFIX].pop()

def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]:
"""
Expand Down
46 changes: 23 additions & 23 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
+ "Please raise a github issue if you need this feature"
)
# Create inverse transform
meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]]
old_affine = np.array(transform[InverseKeys.EXTRA_INFO.value]["old_affine"])
meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_data_key"]]
old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"])
orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1]
inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal)
# Apply inverse
Expand Down Expand Up @@ -312,8 +312,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]]
orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"]
meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_data_key"]]
orig_affine = transform[InverseKeys.EXTRA_INFO]["old_affine"]
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
inverse_transform = Orientation(
axcodes=orig_axcodes,
Expand Down Expand Up @@ -429,9 +429,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
if transform[InverseKeys.DO_TRANSFORM]:
# Create inverse transform
num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"]
num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"]
num_times_to_rotate = 4 - num_times_rotated
inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
# Might need to convert to numpy
Expand Down Expand Up @@ -491,7 +491,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
d = deepcopy(dict(data))
for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners):
transform = self.get_most_recent_transform(d, key)
orig_size = transform[InverseKeys.ORIG_SIZE.value]
orig_size = transform[InverseKeys.ORIG_SIZE]
# Create inverse transform
inverse_transform = Resize(orig_size, mode, align_corners)
# Apply inverse transform
Expand Down Expand Up @@ -582,9 +582,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
transform = self.get_most_recent_transform(d, key)
orig_size = transform[InverseKeys.ORIG_SIZE.value]
orig_size = transform[InverseKeys.ORIG_SIZE]
# Create inverse transform
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"]
inv_affine = np.linalg.inv(fwd_affine)

affine_grid = AffineGrid(affine=inv_affine)
Expand Down Expand Up @@ -710,9 +710,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar

for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
transform = self.get_most_recent_transform(d, key)
orig_size = transform[InverseKeys.ORIG_SIZE.value]
orig_size = transform[InverseKeys.ORIG_SIZE]
# Create inverse transform
fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"]
fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"]
inv_affine = np.linalg.inv(fwd_affine)

affine_grid = AffineGrid(affine=inv_affine)
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
if transform[InverseKeys.DO_TRANSFORM]:
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
Expand Down Expand Up @@ -1098,8 +1098,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"])
if transform[InverseKeys.DO_TRANSFORM]:
flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"])
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
Expand Down Expand Up @@ -1181,7 +1181,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"]
fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"]
inv_rot_mat = np.linalg.inv(fwd_rot_mat)

xform = AffineTransform(
Expand All @@ -1194,7 +1194,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
output = xform(
torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0),
torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)),
spatial_size=transform[InverseKeys.ORIG_SIZE.value],
spatial_size=transform[InverseKeys.ORIG_SIZE],
)
d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32)
# Remove the applied transform
Expand Down Expand Up @@ -1314,9 +1314,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
if transform[InverseKeys.DO_TRANSFORM]:
# Create inverse transform
fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"]
fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"]
inv_rot_mat = np.linalg.inv(fwd_rot_mat)

xform = AffineTransform(
Expand All @@ -1329,7 +1329,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
output = xform(
torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0),
torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)),
spatial_size=transform[InverseKeys.ORIG_SIZE.value],
spatial_size=transform[InverseKeys.ORIG_SIZE],
)
d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32)
# Remove the applied transform
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
align_corners=align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key])
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key])
# Remove the applied transform
self.pop_transform(d, key)

Expand Down Expand Up @@ -1513,9 +1513,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
if transform[InverseKeys.DO_TRANSFORM]:
# Create inverse transform
zoom = np.array(transform[InverseKeys.EXTRA_INFO.value]["zoom"])
zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"])
inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size)
# Apply inverse
d[key] = inverse_transform(
Expand All @@ -1525,7 +1525,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
align_corners=align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key])
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key])
# Remove the applied transform
self.pop_transform(d, key)

Expand Down
2 changes: 1 addition & 1 deletion monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class Method(Enum):
END = "end"


class InverseKeys(Enum):
class InverseKeys:
"""Extra meta data keys used for inverse transforms."""

CLASS_NAME = "class"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def check_match(self, in1, in2):
k1, k2 = k1.value, k2.value
self.check_match(k1, k2)
# Transform ids won't match for windows with multiprocessing, so don't check values
if k1 == InverseKeys.ID.value and sys.platform in ["darwin", "win32"]:
if k1 == InverseKeys.ID and sys.platform in ["darwin", "win32"]:
continue
self.check_match(v1, v2)
elif isinstance(in1, (list, tuple)):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def test_inverse_inferred_seg(self):
data = first(loader)
labels = data["label"].to(device)
segs = model(labels).detach().cpu()
label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
label_transform_key = "label" + InverseKeys.KEY_SUFFIX
segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

segs_dict_decollated = decollate_batch(segs_dict)
Expand Down