diff --git a/docs/source/data.rst b/docs/source/data.rst index 00ca9944cb..6158a564cf 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -343,3 +343,8 @@ Box center .. autofunction:: monai.data.box_utils.box_centers .. autofunction:: monai.data.box_utils.centers_in_boxes .. autofunction:: monai.data.box_utils.boxes_center_distance + +Spatial crop box +~~~~~~~~~~~~~~~~ +.. autofunction:: monai.data.box_utils.spatial_crop_boxes +.. autofunction:: monai.data.box_utils.clip_boxes_to_image diff --git a/monai/apps/detection/transforms/array.py b/monai/apps/detection/transforms/array.py index 901ed60615..ee7b0dac0a 100644 --- a/monai/apps/detection/transforms/array.py +++ b/monai/apps/detection/transforms/array.py @@ -236,7 +236,9 @@ def __init__(self, spatial_size: Union[Sequence[int], int], size_mode: str = "al self.size_mode = look_up_option(size_mode, ["all", "longest"]) self.spatial_size = spatial_size - def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int]) -> NdarrayOrTensor: # type: ignore + def __call__( # type: ignore + self, boxes: NdarrayOrTensor, src_spatial_size: Union[Sequence[int], int] # type: ignore + ) -> NdarrayOrTensor: """ Args: boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index 37fc51c33d..8af21a31e9 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -806,9 +806,13 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso """ Compute the generalized intersection over union (GIoU) of two sets of boxes. The two inputs can have different shapes and the func return an NxM matrix, - (in contrary to ``box_pair_giou``, which requires the inputs to have the same + (in contrary to :func:`~monai.data.box_utils.box_pair_giou` , which requires the inputs to have the same shape and returns ``N`` values). + Args: + boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + Returns: GIoU, with size of (N,M) and same data type as ``boxes1`` @@ -860,11 +864,13 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor: """ Compute the generalized intersection over union (GIoU) of a pair of boxes. - The two inputs should have the same shape. + The two inputs should have the same shape and the func return an (N,) array, + (in contrary to :func:`~monai.data.box_utils.box_giou` , which does not require the inputs to have the same + shape and returns ``NxM`` matrix). Args: - boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode - boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be StandardMode + boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode`` Returns: paired GIoU, with size of (N,) and same data type as ``boxes1`` @@ -932,3 +938,77 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr # convert tensor back to numpy if needed giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1) return giou + + +def spatial_crop_boxes( + boxes: NdarrayOrTensor, + roi_start: Union[Sequence[int], NdarrayOrTensor], + roi_end: Union[Sequence[int], NdarrayOrTensor], + remove_empty: bool = True, +) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + """ + This function generate the new boxes when the corresponding image is cropped to the given ROI. + When ``remove_empty=True``, it makes sure the bounding boxes are within the new cropped image. + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + roi_start: voxel coordinates for start of the crop ROI, negative values allowed. + roi_end: voxel coordinates for end of the crop ROI, negative values allowed. + remove_empty: whether to remove the boxes that are actually empty + + Returns: + - cropped boxes, boxes[keep], does not share memory with original boxes + - ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``. + """ + + roi_start_torch, *_ = convert_data_type( + data=roi_start, output_type=torch.Tensor, dtype=torch.int16, wrap_sequence=True + ) + roi_end_torch, *_ = convert_to_dst_type(src=roi_end, dst=roi_start_torch, wrap_sequence=True) + roi_end_torch = torch.maximum(roi_end_torch, roi_start_torch) + + # convert numpy to tensor if needed + boxes_t, *_ = convert_data_type(deepcopy(boxes), torch.Tensor) + + # convert to float32 since torch.clamp_ does not support float16 + compute_dtype = torch.float32 + boxes_t = boxes_t.to(dtype=compute_dtype) + + # makes sure the bounding boxes are within the patch + spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=roi_end) + for axis in range(0, spatial_dims): + boxes_t[:, axis].clamp_(min=roi_start_torch[axis], max=roi_end_torch[axis] - TO_REMOVE) + boxes_t[:, axis + spatial_dims].clamp_(min=roi_start_torch[axis], max=roi_end_torch[axis] - TO_REMOVE) + boxes_t[:, axis] -= roi_start_torch[axis] + boxes_t[:, axis + spatial_dims] -= roi_start_torch[axis] + + # remove the boxes that are actually empty + if remove_empty: + keep_t = boxes_t[:, spatial_dims] >= boxes_t[:, 0] + 1 - TO_REMOVE + for axis in range(1, spatial_dims): + keep_t = keep_t & (boxes_t[:, axis + spatial_dims] >= boxes_t[:, axis] + 1 - TO_REMOVE) + boxes_t = boxes_t[keep_t] + + # convert tensor back to numpy if needed + boxes_keep, *_ = convert_to_dst_type(src=boxes_t, dst=boxes) + keep, *_ = convert_to_dst_type(src=keep_t, dst=boxes, dtype=keep_t.dtype) + + return boxes_keep, keep + + +def clip_boxes_to_image( + boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], NdarrayOrTensor], remove_empty: bool = True +) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: + """ + This function clips the ``boxes`` to makes sure the bounding boxes are within the image. + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + spatial_size: The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3]. + remove_empty: whether to remove the boxes that are actually empty + + Returns: + updated box + """ + spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=spatial_size) + return spatial_crop_boxes(boxes, roi_start=[0] * spatial_dims, roi_end=spatial_size, remove_empty=remove_empty) diff --git a/tests/test_box_utils.py b/tests/test_box_utils.py index 891baa20c5..94731a2eb1 100644 --- a/tests/test_box_utils.py +++ b/tests/test_box_utils.py @@ -27,6 +27,7 @@ box_pair_giou, boxes_center_distance, centers_in_boxes, + clip_boxes_to_image, convert_box_mode, convert_box_to_standard_mode, ) @@ -142,6 +143,7 @@ def test_value(self, input_data, mode2, expected_box, expected_area): boxes1 = convert_data_type(input_data["boxes"], dtype=np.float32)[0] mode1 = input_data["mode"] half_bool = input_data["half"] + spatial_size = input_data["spatial_size"] # test float16 if half_bool: @@ -192,6 +194,15 @@ def test_value(self, input_data, mode2, expected_box, expected_area): center_dist, _, _ = boxes_center_distance(boxes1=result_standard[0:1, :], boxes2=result_standard[0:1, :]) assert_allclose(center_dist, np.array([[0.0]]), type_test=False) + # test clip_boxes_to_image + clipped_boxes, keep = clip_boxes_to_image(expected_box_standard, spatial_size, remove_empty=True) + assert_allclose( + expected_box_standard[keep, :], expected_box_standard[1:, :], type_test=True, device_test=True, atol=0.0 + ) + assert_allclose( + id(clipped_boxes) != id(expected_box_standard), True, type_test=False, device_test=False, atol=0.0 + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index ca5c4be33b..217c3479a4 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -267,7 +267,7 @@ def test_amp(self): im_conv = conv(im) with torch.cuda.amp.autocast(): im_conv2 = conv(im) - self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) + self.check(im_conv2, im_conv, ids=False, rtol=1e-2, atol=1e-2) def test_out(self): """Test when `out` is given as an argument."""