Skip to content
Merged
5 changes: 5 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,8 @@ Box center
.. autofunction:: monai.data.box_utils.box_centers
.. autofunction:: monai.data.box_utils.centers_in_boxes
.. autofunction:: monai.data.box_utils.boxes_center_distance

Spatial crop box
~~~~~~~~~~~~~~~~
.. autofunction:: monai.data.box_utils.spatial_crop_boxes
.. autofunction:: monai.data.box_utils.clip_boxes_to_image
4 changes: 3 additions & 1 deletion monai/apps/detection/transforms/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def __init__(self, spatial_size: Union[Sequence[int], int], size_mode: str = "al
self.size_mode = look_up_option(size_mode, ["all", "longest"])
self.spatial_size = spatial_size

def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]) -> NdarrayOrTensor: # type: ignore
def __call__( # type: ignore
self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int] # type: ignore
) -> NdarrayOrTensor:
"""
Args:
boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
Expand Down
88 changes: 84 additions & 4 deletions monai/data/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,9 +806,13 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
"""
Compute the generalized intersection over union (GIoU) of two sets of boxes.
The two inputs can have different shapes and the func return an NxM matrix,
(in contrary to ``box_pair_giou``, which requires the inputs to have the same
(in contrary to :func:`~monai.data.box_utils.box_pair_giou` , which requires the inputs to have the same
shape and returns ``N`` values).

Args:
boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

Returns:
GIoU, with size of (N,M) and same data type as ``boxes1``

Expand Down Expand Up @@ -860,11 +864,13 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Compute the generalized intersection over union (GIoU) of a pair of boxes.
The two inputs should have the same shape.
The two inputs should have the same shape and the func return an (N,) array,
(in contrary to :func:`~monai.data.box_utils.box_giou` , which does not require the inputs to have the same
shape and returns ``NxM`` matrix).

Args:
boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode
boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be StandardMode
boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``

Returns:
paired GIoU, with size of (N,) and same data type as ``boxes1``
Expand Down Expand Up @@ -932,3 +938,77 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
# convert tensor back to numpy if needed
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
return giou


def spatial_crop_boxes(
boxes: NdarrayOrTensor,
roi_start: Union[Sequence[int], NdarrayOrTensor],
roi_end: Union[Sequence[int], NdarrayOrTensor],
remove_empty: bool = True,
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
This function generate the new boxes when the corresponding image is cropped to the given ROI.
When ``remove_empty=True``, it makes sure the bounding boxes are within the new cropped image.

Args:
boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
roi_start: voxel coordinates for start of the crop ROI, negative values allowed.
roi_end: voxel coordinates for end of the crop ROI, negative values allowed.
remove_empty: whether to remove the boxes that are actually empty

Returns:
- cropped boxes, boxes[keep], does not share memory with original boxes
Comment thread
Can-Zhao marked this conversation as resolved.
- ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``.
"""

roi_start_torch, *_ = convert_data_type(
data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True
)
roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True)
roi_end_torch = torch.maximum(roi_end_torch, roi_start_torch)

# convert numpy to tensor if needed
boxes_t, *_ = convert_data_type(deepcopy(boxes), torch.Tensor)

# convert to float32 since torch.clamp_ does not support float16
compute_dtype = torch.float32
boxes_t = boxes_t.to(dtype=compute_dtype)

# makes sure the bounding boxes are within the patch
spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=roi_end)
for axis in range(0, spatial_dims):
boxes_t[:, axis].clamp_(min=roi_start_torch[axis], max=roi_end_torch[axis] - TO_REMOVE)
boxes_t[:, axis + spatial_dims].clamp_(min=roi_start_torch[axis], max=roi_end_torch[axis] - TO_REMOVE)
boxes_t[:, axis] -= roi_start_torch[axis]
boxes_t[:, axis + spatial_dims] -= roi_start_torch[axis]

# remove the boxes that are actually empty
if remove_empty:
keep_t = boxes_t[:, spatial_dims] >= boxes_t[:, 0] + 1 - TO_REMOVE
for axis in range(1, spatial_dims):
keep_t = keep_t & (boxes_t[:, axis + spatial_dims] >= boxes_t[:, axis] + 1 - TO_REMOVE)
boxes_t = boxes_t[keep_t]

# convert tensor back to numpy if needed
boxes_keep, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
keep, *_ = convert_to_dst_type(src=keep_t, dst=boxes, dtype=keep_t.dtype)

return boxes_keep, keep


def clip_boxes_to_image(
boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], NdarrayOrTensor], remove_empty: bool = True
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
This function clips the ``boxes`` to makes sure the bounding boxes are within the image.

Args:
boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
spatial_size: The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3].
remove_empty: whether to remove the boxes that are actually empty

Returns:
updated box
"""
spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=spatial_size)
return spatial_crop_boxes(boxes, roi_start=[0] * spatial_dims, roi_end=spatial_size, remove_empty=remove_empty)
11 changes: 11 additions & 0 deletions tests/test_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
box_pair_giou,
boxes_center_distance,
centers_in_boxes,
clip_boxes_to_image,
convert_box_mode,
convert_box_to_standard_mode,
)
Expand Down Expand Up @@ -142,6 +143,7 @@ def test_value(self, input_data, mode2, expected_box, expected_area):
boxes1 = convert_data_type(input_data["boxes"], dtype=np.float32)[0]
mode1 = input_data["mode"]
half_bool = input_data["half"]
spatial_size = input_data["spatial_size"]

# test float16
if half_bool:
Expand Down Expand Up @@ -192,6 +194,15 @@ def test_value(self, input_data, mode2, expected_box, expected_area):
center_dist, _, _ = boxes_center_distance(boxes1=result_standard[0:1, :], boxes2=result_standard[0:1, :])
assert_allclose(center_dist, np.array([[0.0]]), type_test=False)

# test clip_boxes_to_image
clipped_boxes, keep = clip_boxes_to_image(expected_box_standard, spatial_size, remove_empty=True)
assert_allclose(
expected_box_standard[keep, :], expected_box_standard[1:, :], type_test=True, device_test=True, atol=0.0
)
assert_allclose(
id(clipped_boxes) != id(expected_box_standard), True, type_test=False, device_test=False, atol=0.0
)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def test_amp(self):
im_conv = conv(im)
with torch.cuda.amp.autocast():
im_conv2 = conv(im)
self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3)
self.check(im_conv2, im_conv, ids=False, rtol=1e-2, atol=1e-2)

def test_out(self):
"""Test when `out` is given as an argument."""
Expand Down