From 06ae905f8e9c07e89f4835b5289619c03145b929 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Mon, 3 Jan 2022 15:12:44 -0500 Subject: [PATCH 1/2] add box util in monai/data --- monai/data/box_utils.py | 379 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 379 insertions(+) create mode 100644 monai/data/box_utils.py diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py new file mode 100644 index 0000000000..67aabc0792 --- /dev/null +++ b/monai/data/box_utils.py @@ -0,0 +1,379 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Sequence, Union + +import numpy as np +import torch + +import point_utils + +SUPPORT_MODE = ["xxyy","xxyyzz","xyxy", "xyzxyz", "xywh", "xyzwhd"] +STANDARD_MODE = ["xxyy","xxyyzz"] # [2d_mode, 3d_mode] + +# TO_REMOVE = 0 if in 'xxyy','xxyyzz' mode, the bottom-right corner is not included in the box, +# i.e., when x_min=1, x_max=2, we have w = 1 +# if in 'xxyy','xxyyzz' mode, the bottom-right corner is included in the box, +# i.e., when x_min=1, x_max=2, we have w = 2 +TO_REMOVE = 0 # x_max-x_min = w -TO_REMOVE + +""" +The following variables share the same definition across the functions in this file. +Args: + bbox: Nx4 or Nx6 torch tensor + mode: choose from SUPPORT_MODE. If mode is not given, these funcs will assume mode is STANDARD_MODE + image_size: Length of 2 or 3. Data format is list, or np.ndarray, or tensor of int +""" + + +def check_support_mode(mode): + """ + Check if the mode is supported + """ + if mode not in SUPPORT_MODE: + raise ValueError("mode should be a string in {}.".format(SUPPORT_MODE)) + return + +def check_standard_mode(mode): + """ + Check if the mode is supported + """ + if mode not in STANDARD_MODE: + raise ValueError("Standard mode should be a string in {}.".format(STANDARD_MODE)) + return + +def convert_to_list(in_sequence: Union[Sequence, torch.Tensor, np.ndarray]) -> list: + """ + convert a torch.Tensor, or np array input to list + Args: + in_sequence: + Returns: in_sequence_list + + """ + in_sequence_list = deepcopy(in_sequence) + if torch.is_tensor(in_sequence): + in_sequence_list.cpu().detach().numpy().tolist() + elif isinstance(in_sequence, np.ndarray): + in_sequence_list.tolist() + elif not isinstance(in_sequence, list): + in_sequence_list = list(in_sequence_list) + return in_sequence_list + +def get_dimension( + bbox: torch.Tensor = None, image_size: Union[Sequence[int], torch.Tensor, np.ndarray] = None, mode: str = None +) -> int: + """ + Get spatial dimension for the giving setting. + Missing input is allowed. But at least one of the input value should be given. + Returns: spatial_dimension + """ + spatial_dims = set() + if image_size is not None: + spatial_dims.add(len(image_size)) + if mode is not None: + spatial_dims.add(len(mode) / 2) + if bbox is not None: + spatial_dims.add(int(bbox.shape[1] / 2)) + spatial_dims = list(spatial_dims) + if len(spatial_dims) == 0: + raise ValueError("At least one of bbox, image_size, and mode needs to be non-empty.") + elif len(spatial_dims) == 1: + if spatial_dims[0] not in [2, 3]: + raise ValueError("Images should have 2 or 3 dimensions, got {}".format(spatial_dims[0])) + return int(spatial_dims[0]) + else: + raise ValueError("The dimension of bbox, image_size, mode should match with each other.") + + +def get_standard_mode(spatial_dims: int) -> str: + """ + Get the mode name for the given spatial dimension + Args: + spatial_dims: 2 or 3 + + Returns: mode + + """ + if spatial_dims == 2: + return STANDARD_MODE[0] + elif spatial_dims == 3: + return STANDARD_MODE[1] + else: + ValueError("Images should have 2 or 3 dimensions, got {}".format(spatial_dims)) + +def point_interp(point1: Union[Sequence, torch.Tensor, np.ndarray], zoom: Union[Sequence[float], float]) -> Union[Sequence, torch.Tensor, np.ndarray]: + """ + Convert point position from one pixel/voxel size to another pixel/voxel size + Args: + point1: point coordinate on an image with pixel/voxel size of pix_size1 + zoom: The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. + Returns: + point2: point coordinate on an image with pixel/voxel size of pix_size2 + """ + # make sure the spatial dimensions of the inputs match with each other + spatial_dims = len(point1) + if spatial_dims not in [2, 3]: + raise ValueError("Images should have 2 or 3 dimensions, got {}".format(spatial_dims)) + + # compute new point + point2 = deepcopy(point1) + _zoom = monai.utils.misc.ensure_tuple_rep(zoom, spatial_dims) + for axis in range(0,spatial_dims): + point2[axis] = point1[axis]*_zoom[axis] + return point2 + +def box_interp(bbox1: torch.Tensor, zoom: Union[Sequence[float], float], mode1: str = None) -> torch.Tensor: + """ + Interpolate bbox + Args: + zoom: The zoom factor along the spatial axes. + If a float, zoom is the same for each spatial axis. + If a sequence, zoom should contain one value for each spatial axis. + + Returns: + bbox2: returned bbox has the same mode as bbox1 + """ + if mode1 == None: + mode1 = get_standard_mode( int(bbox1.shape[1] / 2) ) + check_support_mode(mode1) + spatial_dims = get_dimension(bbox=bbox1, mode=mode1) + + mode_standard = get_standard_mode(spatial_dims) + bbox1_standard = box_convert_mode(bbox1=bbox1, mode1=mode1, mode2=mode_standard) + + corner_lt = point_utils.point_interp(bbox1_standard[:,::2],zoom) + corner_rb = point_utils.point_interp(bbox1_standard[:, 1::2],zoom) + + bbox2_standard_interp = deepcopy(bbox2_standard) + bbox2_standard_interp[:,::2] = corner_lt + bbox2_standard_interp[:,1::2] = corner_rb + + return box_convert_mode(bbox1=bbox2_standard_interp, mode1=mode_standard, mode2=mode1) + +def split_into_corners(bbox: torch.Tensor, mode: str): + """ + This internal function outputs the corner coordinates of the bbox + + Returns: + if 2D image, outputs (xmin, xmax, ymin, ymax) + if 3D images, outputs (xmin, xmax, ymin, ymax, zmin, zmax) + xmin for example, is a Nx1 tensor + + """ + check_support_mode(mode) + if mode in STANDARD_MODE: + return bbox.split(1, dim=-1) + elif mode == "xyzxyz": + xmin, ymin, zmin, xmax, ymax, zmax = bbox.split(1, dim=-1) + return ( + xmin, + xmax, + ymin, + ymax, + zmin, + zmax, + ) + elif mode == "xyxy": + xmin, ymin, xmax, ymax = bbox.split(1, dim=-1) + return ( + xmin, + xmax, + ymin, + ymax + ) + elif mode == "xyzwhd": + xmin, ymin, zmin, w, h, d = = bbox.split(1, dim=-1) + return ( + xmin, + xmin + (w - TO_REMOVE).clamp(min=0), + ymin, + ymin + (h - TO_REMOVE).clamp(min=0), + zmin, + zmin + (d - TO_REMOVE).clamp(min=0), + ) + elif mode == "xywh": + xmin, ymin, w, h = bbox.split(1, dim=-1) + return (xmin, xmin + (w - TO_REMOVE).clamp(min=0), ymin, ymin + (h - TO_REMOVE).clamp(min=0) ) + else: + raise RuntimeError("Should not be here") + + +def box_convert_mode(bbox1: torch.Tensor, mode1: str, mode2: str) -> torch.Tensor: + """ + This function converts the bbox1 in mode 1 to the mode2 + """ + # 1. check whether the bbox and the new mode is valid + check_support_mode(mode1) + check_support_mode(mode2) + + spatial_dims = get_dimension(bbox=bbox1, mode=mode1) + if len(mode1) != len(mode2): + raise ValueError("The dimension of the new mode should have the same spatial dimension as the old mode.") + + # 2. if mode not changed, return original boxlist + if mode1 == mode2: + return deepcopy(bbox1) + + # 3. convert mode for bbox + if mode2 in STANDARD_MODE: + corners = split_into_corners(deepcopy(bbox1), mode1) + return torch.cat(corners, dim=-1) + + if spatial_dims == 3: + xmin, xmax, ymin, ymax, zmin, zmax = split_into_corners(deepcopy(bbox1), mode1) + if mode2 == "xyzxyz": + bbox2 = torch.cat((xmin, ymin, zmin, xmax, ymax, zmax), dim=-1) + elif mode2 == "xyzwhd": + bbox2 = torch.cat( + (xmin, ymin, zmin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE, zmax - zmin + TO_REMOVE), dim=-1 + ) + else: + raise ValueError("We support only bbox mode in "+str(SUPPORT_MODE)+", got {}".format(mode2)) + elif spatial_dims == 2: + xmin, xmax, ymin, ymax = split_into_corners(deepcopy(bbox1), mode1) + if mode2 == "xyxy": + bbox2 = torch.cat((xmin, ymin, xmax, ymax), dim=-1) + elif mode2 == "xywh": + bbox2 = torch.cat((xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1) + else: + raise ValueError("We support only bbox mode in "+str(SUPPORT_MODE)+", got {}".format(mode2)) + else: + raise ValueError("Images should have 2 or 3 dimensions, got {}".format(spatial_dims)) + + return bbox2 + + +def box_convert_standard_mode(bbox: torch.Tensor, mode: str) -> torch.Tensor: + """ + This function convert the bbox in mode 1 to 'xyxy' or 'xyzxyz' + """ + check_support_mode(mode) + spatial_dims = get_dimension(bbox=bbox, mode=mode) + mode_standard = get_standard_mode(spatial_dims) + return box_convert_mode(bbox1=bbox, mode1=mode, mode2=mode_standard) + + +def box_area(bbox: torch.Tensor, mode: str = None) -> torch.tensor: + """ + This function computes the area of each box + Returns: + area: 1-D tensor + """ + + if mode == None: + mode = get_standard_mode( int(bbox.shape[1] / 2) ) + check_standard_mode(mode) + spatial_dims = get_dimension(bbox=bbox, mode=mode) + + area = bbox[:, 1] - bbox[:, 0] + TO_REMOVE + for axis in range(1, spatial_dims): + area = area * (bbox[:, 2*axis +1] - bbox[:, 2*axis] + TO_REMOVE) + + return area + + +def box_clip_to_image( + bbox: torch.Tensor, image_size: Union[Sequence[int], torch.Tensor, np.ndarray], mode: str = None, remove_empty: bool = True +) -> dict: + """ + This function makes sure the bounding boxes are within the image. + Args: + remove_empty: whether to remove the boxes that are actually empty + Returns: + updated box + """ + if mode == None: + mode = get_standard_mode( int(bbox.shape[1] / 2) ) + check_standard_mode(mode) + spatial_dims = get_dimension(bbox=bbox, image_size=image_size, mode=mode) + new_bbox = deepcopy(bbox) + if bbox.shape[0] == 0: + return deepcopy(bbox) + + # 1. convert to standard mode + mode_standard = get_standard_mode(spatial_dims) + new_bbox = box_convert_mode(bbox1=new_bbox, mode1=mode, mode2=mode_standard) + + # 2. makes sure the bounding boxes are within the image + for axis in range(0, spatial_dims): + new_bbox[:, 2*axis].clamp_(min=0, max=image_size[axis] - TO_REMOVE) + new_bbox[:, 2*axis + 1].clamp_(min=0, max=image_size[axis] - TO_REMOVE) + + # 3. remove the boxes that are actually empty + if remove_empty: + keep = (new_bbox[:, 1] > new_bbox[:, 0]) & (new_bbox[:, 3] > new_bbox[:, 2]) + if spatial_dims == 3: + keep = keep & (new_bbox[:, 5] > new_bbox[:, 4]) + new_bbox = new_bbox[keep] + + # 4. return updated boxlist + new_bbox = box_convert_mode(bbox1=new_bbox, mode1=mode_standard, mode2=mode) + + return new_bbox + + +def box_iou(bbox1: torch.Tensor, bbox2: torch.Tensor, mode1: str = None, mode2: str = None, gpubool: bool = True): + """ + Compute the intersection over union of two set of boxes. This function is not differentialable. + + IMPORTANT: Please run box_clip_to_image(bbox, image_size, mode, remove_empty=True) before computing IoU + + Implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py + with slight modifications. + + Arguments: + bbox1: Nx4 or Nx6, make sure they are non-empty + bbox2: Mx4 or Mx6, make sure they are non-empty + gpubool: whether to send the final IoU results to GPU + + Returns: + (tensor) iou, sized [N,M]. + + Reference: + https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py + """ + + if mode1 == None: + mode1 = get_standard_mode( int(bbox1.shape[1] / 2) ) + if mode2 == None: + mode2 = get_standard_mode( int(bbox2.shape[1] / 2) ) + check_standard_mode(mode1) + check_standard_mode(mode2) + spatial_dims = get_dimension(bbox=bbox1, mode=mode1) + + # we do computation on cpu + device = bbox1.device + + # compute area for the bbox + area1 = box_area(bbox=bbox1, mode=mode1).cpu() # Nx1 + area2 = box_area(bbox=bbox2, mode=mode2).cpu() # Mx1 + + # get the left top and right bottom points for the NxM combinations + lt = torch.max(bbox1[:, None, ::2], bbox2[:, ::2]) # [N,M,spatial_dims] left top + rb = torch.min( + bbox1_corner[:, None, 1::2], bbox2_corner[:, 1::2] + ) # [N,M,spatial_dims] right bottom + # compute size for the intersection region for the NxM combinations + wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,spatial_dims] + inter = wh[:, :, 0] # [N,M] + for axis in range(1, spatial_dims): + inter = inter * wh[:, :, axis] + + # compute IoU + iou = inter / (area1[:, None] + area2 - inter + torch.finfo(torch.float32).eps) # [N,M,spatial_dims] + + if gpubool: + iou = iou.to(device) # [N,M,spatial_dims] + + return iou From 60bf24dbe94d16d72943c233929dc47c0bab5782 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Mon, 3 Jan 2022 15:24:10 -0500 Subject: [PATCH 2/2] add box util in monai/data --- monai/data/box_utils.py | 86 +++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index 67aabc0792..daca40ef85 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -17,14 +17,14 @@ import point_utils -SUPPORT_MODE = ["xxyy","xxyyzz","xyxy", "xyzxyz", "xywh", "xyzwhd"] -STANDARD_MODE = ["xxyy","xxyyzz"] # [2d_mode, 3d_mode] +SUPPORT_MODE = ["xxyy", "xxyyzz", "xyxy", "xyzxyz", "xywh", "xyzwhd"] +STANDARD_MODE = ["xxyy", "xxyyzz"] # [2d_mode, 3d_mode] # TO_REMOVE = 0 if in 'xxyy','xxyyzz' mode, the bottom-right corner is not included in the box, # i.e., when x_min=1, x_max=2, we have w = 1 # if in 'xxyy','xxyyzz' mode, the bottom-right corner is included in the box, # i.e., when x_min=1, x_max=2, we have w = 2 -TO_REMOVE = 0 # x_max-x_min = w -TO_REMOVE +TO_REMOVE = 0 # x_max-x_min = w -TO_REMOVE """ The following variables share the same definition across the functions in this file. @@ -43,6 +43,7 @@ def check_support_mode(mode): raise ValueError("mode should be a string in {}.".format(SUPPORT_MODE)) return + def check_standard_mode(mode): """ Check if the mode is supported @@ -51,6 +52,7 @@ def check_standard_mode(mode): raise ValueError("Standard mode should be a string in {}.".format(STANDARD_MODE)) return + def convert_to_list(in_sequence: Union[Sequence, torch.Tensor, np.ndarray]) -> list: """ convert a torch.Tensor, or np array input to list @@ -68,11 +70,12 @@ def convert_to_list(in_sequence: Union[Sequence, torch.Tensor, np.ndarray]) -> l in_sequence_list = list(in_sequence_list) return in_sequence_list + def get_dimension( bbox: torch.Tensor = None, image_size: Union[Sequence[int], torch.Tensor, np.ndarray] = None, mode: str = None ) -> int: """ - Get spatial dimension for the giving setting. + Get spatial dimension for the giving setting. Missing input is allowed. But at least one of the input value should be given. Returns: spatial_dimension """ @@ -110,7 +113,10 @@ def get_standard_mode(spatial_dims: int) -> str: else: ValueError("Images should have 2 or 3 dimensions, got {}".format(spatial_dims)) -def point_interp(point1: Union[Sequence, torch.Tensor, np.ndarray], zoom: Union[Sequence[float], float]) -> Union[Sequence, torch.Tensor, np.ndarray]: + +def point_interp( + point1: Union[Sequence, torch.Tensor, np.ndarray], zoom: Union[Sequence[float], float] +) -> Union[Sequence, torch.Tensor, np.ndarray]: """ Convert point position from one pixel/voxel size to another pixel/voxel size Args: @@ -129,10 +135,11 @@ def point_interp(point1: Union[Sequence, torch.Tensor, np.ndarray], zoom: Union[ # compute new point point2 = deepcopy(point1) _zoom = monai.utils.misc.ensure_tuple_rep(zoom, spatial_dims) - for axis in range(0,spatial_dims): - point2[axis] = point1[axis]*_zoom[axis] + for axis in range(0, spatial_dims): + point2[axis] = point1[axis] * _zoom[axis] return point2 + def box_interp(bbox1: torch.Tensor, zoom: Union[Sequence[float], float], mode1: str = None) -> torch.Tensor: """ Interpolate bbox @@ -144,23 +151,24 @@ def box_interp(bbox1: torch.Tensor, zoom: Union[Sequence[float], float], mode1: Returns: bbox2: returned bbox has the same mode as bbox1 """ - if mode1 == None: - mode1 = get_standard_mode( int(bbox1.shape[1] / 2) ) + if mode1 is None: + mode1 = get_standard_mode(int(bbox1.shape[1] / 2)) check_support_mode(mode1) spatial_dims = get_dimension(bbox=bbox1, mode=mode1) mode_standard = get_standard_mode(spatial_dims) bbox1_standard = box_convert_mode(bbox1=bbox1, mode1=mode1, mode2=mode_standard) - corner_lt = point_utils.point_interp(bbox1_standard[:,::2],zoom) - corner_rb = point_utils.point_interp(bbox1_standard[:, 1::2],zoom) + corner_lt = point_utils.point_interp(bbox1_standard[:, ::2], zoom) + corner_rb = point_utils.point_interp(bbox1_standard[:, 1::2], zoom) bbox2_standard_interp = deepcopy(bbox2_standard) - bbox2_standard_interp[:,::2] = corner_lt - bbox2_standard_interp[:,1::2] = corner_rb + bbox2_standard_interp[:, ::2] = corner_lt + bbox2_standard_interp[:, 1::2] = corner_rb return box_convert_mode(bbox1=bbox2_standard_interp, mode1=mode_standard, mode2=mode1) + def split_into_corners(bbox: torch.Tensor, mode: str): """ This internal function outputs the corner coordinates of the bbox @@ -181,30 +189,25 @@ def split_into_corners(bbox: torch.Tensor, mode: str): xmax, ymin, ymax, - zmin, + zmin, zmax, ) elif mode == "xyxy": xmin, ymin, xmax, ymax = bbox.split(1, dim=-1) - return ( - xmin, - xmax, - ymin, - ymax - ) + return (xmin, xmax, ymin, ymax) elif mode == "xyzwhd": - xmin, ymin, zmin, w, h, d = = bbox.split(1, dim=-1) + xmin, ymin, zmin, w, h, d = bbox.split(1, dim=-1) return ( xmin, xmin + (w - TO_REMOVE).clamp(min=0), ymin, ymin + (h - TO_REMOVE).clamp(min=0), - zmin, + zmin, zmin + (d - TO_REMOVE).clamp(min=0), ) elif mode == "xywh": xmin, ymin, w, h = bbox.split(1, dim=-1) - return (xmin, xmin + (w - TO_REMOVE).clamp(min=0), ymin, ymin + (h - TO_REMOVE).clamp(min=0) ) + return (xmin, xmin + (w - TO_REMOVE).clamp(min=0), ymin, ymin + (h - TO_REMOVE).clamp(min=0)) else: raise RuntimeError("Should not be here") @@ -239,7 +242,7 @@ def box_convert_mode(bbox1: torch.Tensor, mode1: str, mode2: str) -> torch.Tenso (xmin, ymin, zmin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE, zmax - zmin + TO_REMOVE), dim=-1 ) else: - raise ValueError("We support only bbox mode in "+str(SUPPORT_MODE)+", got {}".format(mode2)) + raise ValueError("We support only bbox mode in " + str(SUPPORT_MODE) + ", got {}".format(mode2)) elif spatial_dims == 2: xmin, xmax, ymin, ymax = split_into_corners(deepcopy(bbox1), mode1) if mode2 == "xyxy": @@ -247,7 +250,7 @@ def box_convert_mode(bbox1: torch.Tensor, mode1: str, mode2: str) -> torch.Tenso elif mode2 == "xywh": bbox2 = torch.cat((xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1) else: - raise ValueError("We support only bbox mode in "+str(SUPPORT_MODE)+", got {}".format(mode2)) + raise ValueError("We support only bbox mode in " + str(SUPPORT_MODE) + ", got {}".format(mode2)) else: raise ValueError("Images should have 2 or 3 dimensions, got {}".format(spatial_dims)) @@ -270,21 +273,24 @@ def box_area(bbox: torch.Tensor, mode: str = None) -> torch.tensor: Returns: area: 1-D tensor """ - - if mode == None: - mode = get_standard_mode( int(bbox.shape[1] / 2) ) + + if mode is None: + mode = get_standard_mode(int(bbox.shape[1] / 2)) check_standard_mode(mode) spatial_dims = get_dimension(bbox=bbox, mode=mode) area = bbox[:, 1] - bbox[:, 0] + TO_REMOVE for axis in range(1, spatial_dims): - area = area * (bbox[:, 2*axis +1] - bbox[:, 2*axis] + TO_REMOVE) + area = area * (bbox[:, 2 * axis + 1] - bbox[:, 2 * axis] + TO_REMOVE) return area def box_clip_to_image( - bbox: torch.Tensor, image_size: Union[Sequence[int], torch.Tensor, np.ndarray], mode: str = None, remove_empty: bool = True + bbox: torch.Tensor, + image_size: Union[Sequence[int], torch.Tensor, np.ndarray], + mode: str = None, + remove_empty: bool = True, ) -> dict: """ This function makes sure the bounding boxes are within the image. @@ -293,8 +299,8 @@ def box_clip_to_image( Returns: updated box """ - if mode == None: - mode = get_standard_mode( int(bbox.shape[1] / 2) ) + if mode is None: + mode = get_standard_mode(int(bbox.shape[1] / 2)) check_standard_mode(mode) spatial_dims = get_dimension(bbox=bbox, image_size=image_size, mode=mode) new_bbox = deepcopy(bbox) @@ -307,8 +313,8 @@ def box_clip_to_image( # 2. makes sure the bounding boxes are within the image for axis in range(0, spatial_dims): - new_bbox[:, 2*axis].clamp_(min=0, max=image_size[axis] - TO_REMOVE) - new_bbox[:, 2*axis + 1].clamp_(min=0, max=image_size[axis] - TO_REMOVE) + new_bbox[:, 2 * axis].clamp_(min=0, max=image_size[axis] - TO_REMOVE) + new_bbox[:, 2 * axis + 1].clamp_(min=0, max=image_size[axis] - TO_REMOVE) # 3. remove the boxes that are actually empty if remove_empty: @@ -344,10 +350,10 @@ def box_iou(bbox1: torch.Tensor, bbox2: torch.Tensor, mode1: str = None, mode2: https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py """ - if mode1 == None: - mode1 = get_standard_mode( int(bbox1.shape[1] / 2) ) - if mode2 == None: - mode2 = get_standard_mode( int(bbox2.shape[1] / 2) ) + if mode1 is None: + mode1 = get_standard_mode(int(bbox1.shape[1] / 2)) + if mode2 is None: + mode2 = get_standard_mode(int(bbox2.shape[1] / 2)) check_standard_mode(mode1) check_standard_mode(mode2) spatial_dims = get_dimension(bbox=bbox1, mode=mode1) @@ -361,9 +367,7 @@ def box_iou(bbox1: torch.Tensor, bbox2: torch.Tensor, mode1: str = None, mode2: # get the left top and right bottom points for the NxM combinations lt = torch.max(bbox1[:, None, ::2], bbox2[:, ::2]) # [N,M,spatial_dims] left top - rb = torch.min( - bbox1_corner[:, None, 1::2], bbox2_corner[:, 1::2] - ) # [N,M,spatial_dims] right bottom + rb = torch.min(bbox1_corner[:, None, 1::2], bbox2_corner[:, 1::2]) # [N,M,spatial_dims] right bottom # compute size for the intersection region for the NxM combinations wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,spatial_dims] inter = wh[:, :, 0] # [N,M]