Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions monai/data/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,25 +959,23 @@ def spatial_crop_boxes(
- ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``.
"""

roi_start_torch, *_ = convert_data_type(
data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True
)
roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True)
roi_end_torch = torch.maximum(roi_end_torch, roi_start_torch)

# convert numpy to tensor if needed
boxes_t, *_ = convert_data_type(deepcopy(boxes), torch.Tensor)

# convert to float32 since torch.clamp_ does not support float16
boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)

roi_start_t = convert_to_dst_type(src=roi_start, dst=boxes_t, wrap_sequence=True)[0].to(torch.int16)
roi_end_t = convert_to_dst_type(src=roi_end, dst=boxes_t, wrap_sequence=True)[0].to(torch.int16)
roi_end_t = torch.maximum(roi_end_t, roi_start_t)

# makes sure the bounding boxes are within the patch
spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=roi_end)
for axis in range(0, spatial_dims):
boxes_t[:, axis].clamp_(min=roi_start_torch[axis], max=roi_end_torch[axis] - TO_REMOVE)
boxes_t[:, axis + spatial_dims].clamp_(min=roi_start_torch[axis], max=roi_end_torch[axis] - TO_REMOVE)
boxes_t[:, axis] -= roi_start_torch[axis]
boxes_t[:, axis + spatial_dims] -= roi_start_torch[axis]
boxes_t[:, axis].clamp_(min=roi_start_t[axis], max=roi_end_t[axis] - TO_REMOVE)
boxes_t[:, axis + spatial_dims].clamp_(min=roi_start_t[axis], max=roi_end_t[axis] - TO_REMOVE)
boxes_t[:, axis] -= roi_start_t[axis]
boxes_t[:, axis + spatial_dims] -= roi_start_t[axis]

# remove the boxes that are actually empty
if remove_empty:
Expand Down