From 2abb90b43d3ce60dacb525734cc52a1b53d9453a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 22 Mar 2021 09:28:07 +0000 Subject: [PATCH] remove .value from InverseKeys enum Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/croppad/batch.py | 4 +-- monai/transforms/croppad/dictionary.py | 22 ++++++------ monai/transforms/inverse.py | 20 +++++------ monai/transforms/spatial/dictionary.py | 46 +++++++++++++------------- monai/utils/enums.py | 2 +- tests/test_decollate.py | 2 +- tests/test_inverse.py | 2 +- 7 files changed, 49 insertions(+), 49 deletions(-) diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 7cbf39597c..37ff8618fa 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -119,10 +119,10 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: d = deepcopy(data) for key in d.keys(): - transform_key = str(key) + InverseKeys.KEY_SUFFIX.value + transform_key = str(key) + InverseKeys.KEY_SUFFIX if transform_key in d.keys(): transform = d[transform_key][-1] - if transform[InverseKeys.CLASS_NAME.value] == PadListDataCollate.__name__: + if transform[InverseKeys.CLASS_NAME] == PadListDataCollate.__name__: d[key] = CenterSpatialCrop(transform["orig_size"])(d[key]) # remove transform d[transform_key].pop() diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 822db28467..c3523f3993 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -133,7 +133,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE.value] + orig_size = transform[InverseKeys.ORIG_SIZE] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] @@ -202,7 +202,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] if roi_start.size == 1: @@ -210,7 +210,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar # need to convert [min1,max1,min2,...] to [min1,min2,...] elif roi_start.size == 2 * orig_size.size: roi_start = roi_start[::2] - roi_end = np.array(transform[InverseKeys.ORIG_SIZE.value]) + roi_start + roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) # Apply inverse transform @@ -268,7 +268,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start @@ -323,7 +323,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE.value] + orig_size = transform[InverseKeys.ORIG_SIZE] pad_to_start = np.array(self.cropper.roi_start) pad_to_end = orig_size - self.cropper.roi_end # interleave mins and maxes @@ -369,7 +369,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -449,12 +449,12 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE.value] + orig_size = transform[InverseKeys.ORIG_SIZE] random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO.value]["slices"]): + for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]): pad_to_start[i] = _slice[0] pad_to_end[i] = orig_size[i] - _slice[1] else: @@ -594,8 +594,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) - extra_info = transform[InverseKeys.EXTRA_INFO.value] + orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + extra_info = transform[InverseKeys.EXTRA_INFO] pad_to_start = np.array(extra_info["box_start"]) pad_to_end = orig_size - np.array(extra_info["box_end"]) # interleave mins and maxes @@ -827,7 +827,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE.value]) + orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding. # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions. diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 9708f103e6..3e5b68e8e4 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -72,17 +72,17 @@ def push_transform( orig_size: Optional[Tuple] = None, ) -> None: """Append to list of applied transforms for that key.""" - key_transform = str(key) + InverseKeys.KEY_SUFFIX.value + key_transform = str(key) + InverseKeys.KEY_SUFFIX info = { - InverseKeys.CLASS_NAME.value: self.__class__.__name__, - InverseKeys.ID.value: id(self), - InverseKeys.ORIG_SIZE.value: orig_size or data[key].shape[1:], + InverseKeys.CLASS_NAME: self.__class__.__name__, + InverseKeys.ID: id(self), + InverseKeys.ORIG_SIZE: orig_size or data[key].shape[1:], } if extra_info is not None: - info[InverseKeys.EXTRA_INFO.value] = extra_info + info[InverseKeys.EXTRA_INFO] = extra_info # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) if isinstance(self, RandomizableTransform): - info[InverseKeys.DO_TRANSFORM.value] = self._do_transform + info[InverseKeys.DO_TRANSFORM] = self._do_transform # If this is the first, create list if key_transform not in data: data[key_transform] = [] @@ -90,25 +90,25 @@ def push_transform( def check_transforms_match(self, transform: dict) -> None: """Check transforms are of same instance.""" - if transform[InverseKeys.ID.value] == id(self): + if transform[InverseKeys.ID] == id(self): return # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) if ( torch.multiprocessing.get_start_method(allow_none=False) == "spawn" - and transform[InverseKeys.CLASS_NAME.value] == self.__class__.__name__ + and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ ): return raise RuntimeError("Should inverse most recently applied invertible transform first") def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: """Get most recent transform.""" - transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX.value][-1]) + transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX][-1]) self.check_transforms_match(transform) return transform def pop_transform(self, data: dict, key: Hashable) -> None: """Remove most recent transform.""" - data[str(key) + InverseKeys.KEY_SUFFIX.value].pop() + data[str(key) + InverseKeys.KEY_SUFFIX].pop() def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: """ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 32327ec302..0d5b3436fd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -225,8 +225,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar + "Please raise a github issue if you need this feature" ) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] - old_affine = np.array(transform[InverseKeys.EXTRA_INFO.value]["old_affine"]) + meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_data_key"]] + old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"]) orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse @@ -312,8 +312,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] - orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"] + meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_data_key"]] + orig_affine = transform[InverseKeys.EXTRA_INFO]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, @@ -429,9 +429,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM.value]: + if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform - num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"] + num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) # Might need to convert to numpy @@ -491,7 +491,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar d = deepcopy(dict(data)) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE.value] + orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform inverse_transform = Resize(orig_size, mode, align_corners) # Apply inverse transform @@ -582,9 +582,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE.value] + orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] + fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -710,9 +710,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE.value] + orig_size = transform[InverseKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO.value]["affine"] + fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -1048,7 +1048,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM.value]: + if transform[InverseKeys.DO_TRANSFORM]: # Might need to convert to numpy if isinstance(d[key], torch.Tensor): d[key] = torch.Tensor(d[key]).cpu().numpy() @@ -1098,8 +1098,8 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM.value]: - flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"]) + if transform[InverseKeys.DO_TRANSFORM]: + flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) # Might need to convert to numpy if isinstance(d[key], torch.Tensor): d[key] = torch.Tensor(d[key]).cpu().numpy() @@ -1181,7 +1181,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"] + fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1194,7 +1194,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InverseKeys.ORIG_SIZE.value], + spatial_size=transform[InverseKeys.ORIG_SIZE], ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform @@ -1314,9 +1314,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM.value]: + if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO.value]["rot_mat"] + fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( @@ -1329,7 +1329,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar output = xform( torch.as_tensor(np.ascontiguousarray(d[key]).astype(dtype)).unsqueeze(0), torch.as_tensor(np.ascontiguousarray(inv_rot_mat).astype(dtype)), - spatial_size=transform[InverseKeys.ORIG_SIZE.value], + spatial_size=transform[InverseKeys.ORIG_SIZE], ) d[key] = np.asarray(output.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # Remove the applied transform @@ -1410,7 +1410,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key]) + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key]) # Remove the applied transform self.pop_transform(d, key) @@ -1513,9 +1513,9 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar ): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM.value]: + if transform[InverseKeys.DO_TRANSFORM]: # Create inverse transform - zoom = np.array(transform[InverseKeys.EXTRA_INFO.value]["zoom"]) + zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"]) inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size) # Apply inverse d[key] = inverse_transform( @@ -1525,7 +1525,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar align_corners=align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key]) + d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE])(d[key]) # Remove the applied transform self.pop_transform(d, key) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index d661781616..63d65329af 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -217,7 +217,7 @@ class Method(Enum): END = "end" -class InverseKeys(Enum): +class InverseKeys: """Extra meta data keys used for inverse transforms.""" CLASS_NAME = "class" diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 4dc5a217a7..5b78bbbcf6 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -55,7 +55,7 @@ def check_match(self, in1, in2): k1, k2 = k1.value, k2.value self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing, so don't check values - if k1 == InverseKeys.ID.value and sys.platform in ["darwin", "win32"]: + if k1 == InverseKeys.ID and sys.platform in ["darwin", "win32"]: continue self.check_match(v1, v2) elif isinstance(in1, (list, tuple)): diff --git a/tests/test_inverse.py b/tests/test_inverse.py index f548b53f11..d54855d7c1 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -565,7 +565,7 @@ def test_inverse_inferred_seg(self): data = first(loader) labels = data["label"].to(device) segs = model(labels).detach().cpu() - label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value + label_transform_key = "label" + InverseKeys.KEY_SUFFIX segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict)