From 5c78165e264d9324829770b236b5d53baaf600d7 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 25 May 2022 18:36:43 -0400 Subject: [PATCH 01/13] add ATSS matcher Signed-off-by: Can Zhao --- .../detection/networks/utils/ATSS_matcher.py | 296 ++++++++++++++++++ .../apps/detection/networks/utils/__init__.py | 10 + tests/test_atss_box_matcher.py | 46 +++ 3 files changed, 352 insertions(+) create mode 100644 monai/apps/detection/networks/utils/ATSS_matcher.py create mode 100644 monai/apps/detection/networks/utils/__init__.py create mode 100644 tests/test_atss_box_matcher.py diff --git a/monai/apps/detection/networks/utils/ATSS_matcher.py b/monai/apps/detection/networks/utils/ATSS_matcher.py new file mode 100644 index 0000000000..2f876ace08 --- /dev/null +++ b/monai/apps/detection/networks/utils/ATSS_matcher.py @@ -0,0 +1,296 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ========================================================================= +# Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py +# which has the following license... +# https://github.com/MIC-DKFZ/nnDetection/blob/main/LICENSE +# +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ========================================================================= +# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py +# which has the following license... +# https://github.com/pytorch/vision/blob/main/LICENSE +# +# BSD 3-Clause License + +# Copyright (c) Soumith Chintala 2016, +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +The functions in this script are the almost the same with nnDetection, +https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py +which is modified from torchvision. + +These are the changes compared with nndetection: +- comments and docstrings; +- reformat; +- add a debug option to ATSSMatcher to help the users to tune parameters; +- add a corner case return in ATSSMatcher.compute_matches; +- add support for float16 cpu +""" + +import logging +from abc import ABC +from typing import Callable, Sequence, Tuple, TypeVar + +import torch +from torch import Tensor + +from monai.data.box_utils import COMPUTE_DTYPE, box_iou, boxes_center_distance, centers_in_boxes +from monai.utils.type_conversion import convert_to_tensor + +INF = 100 # not really inf but here it is sufficient + + +class Matcher(ABC): + BELOW_LOW_THRESHOLD: int = -1 + BETWEEN_THRESHOLDS: int = -2 + + def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou): # type: ignore + """ + Matches boxes and anchors to each other + + Args: + similarity_fn: function for similarity computation between + boxes and anchors + """ + self.similarity_fn = similarity_fn + + def __call__( + self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute matches for a single image + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``. + num_anchors_per_level: number of anchors per feature pyramid level + num_anchors_per_loc: number of anchors per position + + Returns: + - matrix which contains the similarity from each boxes + to each anchor [N, M] + - vector which contains the matched box index for all + anchors (if background `BELOW_LOW_THRESHOLD` is used + and if it should be ignored `BETWEEN_THRESHOLDS` is used) + [M] + + Note: + ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`, + also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D + and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D. + """ + if boxes.numel() == 0: + # no ground truth + num_anchors = anchors.shape[0] + match_quality_matrix = torch.tensor([]).to(anchors) + matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD) + return match_quality_matrix, matches + else: + # at least one ground truth + return self.compute_matches( + boxes=boxes, + anchors=anchors, + num_anchors_per_level=num_anchors_per_level, + num_anchors_per_loc=num_anchors_per_loc, + ) + + def compute_matches( + self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute matches + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``. + num_anchors_per_level: number of anchors per feature pyramid level + num_anchors_per_loc: number of anchors per position + + Returns: + - matrix which contains the similarity from each boxes + to each anchor [N, M] + - vector which contains the matched box index for all + anchors (if background `BELOW_LOW_THRESHOLD` is used + and if it should be ignored `BETWEEN_THRESHOLDS` is used) + [M] + """ + raise NotImplementedError + + +class ATSSMatcher(Matcher): + def __init__( + self, + num_candidates: int = 4, + similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou, # type: ignore + center_in_gt: bool = True, + debug=False, + ): + """ + Compute matching based on ATSS + https://arxiv.org/abs/1912.02424 + `Bridging the Gap Between Anchor-based and Anchor-free Detection + via Adaptive Training Sample Selection` + + Args: + num_candidates: number of positions to select candidates from. + Smaller value will result in a higher matcher threshold and less matched candidates. + similarity_fn: function for similarity computation between + boxes and anchors + center_in_gt: If False (default), matched anchor center points do not need + to lie withing the ground truth box. Recommand False for small objects. + If True, will result in a strict matcher and less matched candidates. + debug: if True, will print the matcher threshold in order to + tune ``num_candidates`` and ``center_in_gt``. + """ + super().__init__(similarity_fn=similarity_fn) + self.num_candidates = num_candidates + self.min_dist = 0.01 + self.center_in_gt = center_in_gt + self.debug = debug + logging.info( + f"Running ATSS Matching with num_candidates={self.num_candidates} " f"and center_in_gt {self.center_in_gt}." + ) + + def compute_matches( + self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute matches according to ATTS for a single image + Adapted from + (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss + /loss.py#L180-L184) + + Args: + boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` + anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``. + num_anchors_per_level: number of anchors per feature pyramid level + num_anchors_per_loc: number of anchors per position + + Returns: + Tensor: matrix which contains the similarity from each boxes + to each anchor [N, M] + Tensor: vector which contains the matched box index for all + anchors (if background `BELOW_LOW_THRESHOLD` is used + and if it should be ignored `BETWEEN_THRESHOLDS` is used) + [M] + + Note: + ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`, + also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D + and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D. + """ + num_gt = boxes.shape[0] + num_anchors = anchors.shape[0] + + distances_, _, anchors_center = boxes_center_distance(boxes, anchors) # num_boxes x anchors + distances = convert_to_tensor(distances_) + + # select candidates based on center distance + candidate_idx_list = [] + start_idx = 0 + for _, apl in enumerate(num_anchors_per_level): + end_idx = start_idx + apl * num_anchors_per_loc + + # topk: total number of candidates per position + topk = min(self.num_candidates * num_anchors_per_loc, apl) + # torch.topk() does not support float16 cpu, need conversion to float32 or float64 + _, idx = distances[:, start_idx:end_idx].to(COMPUTE_DTYPE).topk(topk, dim=1, largest=False) + # idx: shape [num_boxes x topk] + candidate_idx_list.append(idx + start_idx) + + start_idx = end_idx + # [num_boxes x num_candidates] (index of candidate anchors) + candidate_idx = torch.cat(candidate_idx_list, dim=1) + + match_quality_matrix = self.similarity_fn(boxes, anchors) # [num_boxes x anchors] + candidate_ious = match_quality_matrix.gather(1, candidate_idx) # [num_boxes, n_candidates] + + # corner case, n_candidates<=1 will make iou_std_per_gt NaN + if candidate_idx.shape[1] <= 1: + matches = -1 * torch.ones((num_anchors,), dtype=torch.long, device=boxes.device) + matches[candidate_idx] = 0 + return match_quality_matrix, matches + + # compute adaptive iou threshold + iou_mean_per_gt = candidate_ious.mean(dim=1) # [num_boxes] + iou_std_per_gt = candidate_ious.std(dim=1) # [num_boxes] + iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt # [num_boxes] + is_pos = candidate_ious >= iou_thresh_per_gt[:, None] # [num_boxes x n_candidates] + if self.debug: + print(f"Anchor matcher threshold: {iou_thresh_per_gt}") + + if self.center_in_gt: # can discard all candidates in case of very small objects :/ + # center point of selected anchors needs to lie within the ground truth + boxes_idx = ( + torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None] + .expand_as(candidate_idx) + .contiguous() + ) # [num_boxes x n_candidates] + is_in_gt_ = centers_in_boxes( + anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist + ) + is_in_gt = convert_to_tensor(is_in_gt_) + is_pos = is_pos & is_in_gt.view_as(is_pos) # [num_boxes x n_candidates] + + # in case on anchor is assigned to multiple boxes, use box with highest IoU + # TODO: think about a better way to do this + for ng in range(num_gt): + candidate_idx[ng, :] += ng * num_anchors + ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1) + index = candidate_idx.view(-1)[is_pos.view(-1)] + ious_inf[index] = match_quality_matrix.view(-1)[index] + ious_inf = ious_inf.view_as(match_quality_matrix) + + matched_vals, matches = ious_inf.max(dim=0) + matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD + # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}") + return match_quality_matrix, matches + + +MatcherType = TypeVar("MatcherType", bound=Matcher) diff --git a/monai/apps/detection/networks/utils/__init__.py b/monai/apps/detection/networks/utils/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/detection/networks/utils/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py new file mode 100644 index 0000000000..43745e3cc2 --- /dev/null +++ b/tests/test_atss_box_matcher.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +import torch +from parameterized import parameterized + +from monai.apps.detection.networks.utils.ATSS_matcher import ATSSMatcher +from monai.data.box_utils import box_iou +from tests.utils import assert_allclose + +TEST_CASES = [] +TEST_CASES.append( + [ + {"num_candidates": 2, "similarity_fn": box_iou, "center_in_gt": False}, + torch.tensor([[0, 1, 2, 3, 2, 5]], dtype=torch.float16), + torch.tensor([[0, 1, 2, 3, 2, 5], [0, 1, 1, 3, 2, 5], [0, 1, 2, 3, 2, 4]], dtype=torch.float16), + [3], + 3, + torch.tensor([0, -1, -1]), + ] +) + + +class TestATSS(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_atss(self, input_param, boxes, anchors, num_anchors_per_level, num_anchors_per_loc, expected_matches): + matcher = ATSSMatcher(**input_param, debug=True) + match_quality_matrix, matches = matcher.compute_matches( + boxes, anchors, num_anchors_per_level, num_anchors_per_loc + ) + assert_allclose(matches, expected_matches, type_test=True, device_test=True, atol=0) + + +if __name__ == "__main__": + unittest.main() From c5cb4417db6167787130d15e6a4076b8ddc96030 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 25 May 2022 18:43:08 -0400 Subject: [PATCH 02/13] docstring Signed-off-by: Can Zhao --- docs/source/apps.rst | 5 +++++ .../apps/detection/networks/utils/ATSS_matcher.py | 14 +++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 1a11fc62c6..644e92a588 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -134,3 +134,8 @@ Applications :members: .. automodule:: monai.apps.detection.transforms.dictionary :members: + +`Matcher` +~~~~~~~~~ +.. automodule:: monai.apps.detection.networks.utils.ATSS_matcher + :members: diff --git a/monai/apps/detection/networks/utils/ATSS_matcher.py b/monai/apps/detection/networks/utils/ATSS_matcher.py index 2f876ace08..0aa94be16e 100644 --- a/monai/apps/detection/networks/utils/ATSS_matcher.py +++ b/monai/apps/detection/networks/utils/ATSS_matcher.py @@ -61,16 +61,16 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -The functions in this script are the almost the same with nnDetection, +The functions in this script are tadapted from nnDetection, https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py -which is modified from torchvision. +which is adapted from torchvision. These are the changes compared with nndetection: -- comments and docstrings; -- reformat; -- add a debug option to ATSSMatcher to help the users to tune parameters; -- add a corner case return in ATSSMatcher.compute_matches; -- add support for float16 cpu +1) comments and docstrings; +2) reformat; +3) add a debug option to ATSSMatcher to help the users to tune parameters; +4) add a corner case return in ATSSMatcher.compute_matches; +5) add support for float16 cpu """ import logging From 48dc4cd8db645d0f8716a19bd2cc75d40876a15d Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 25 May 2022 18:51:48 -0400 Subject: [PATCH 03/13] reformat Signed-off-by: Can Zhao --- tests/test_atss_box_matcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py index 43745e3cc2..77092a602e 100644 --- a/tests/test_atss_box_matcher.py +++ b/tests/test_atss_box_matcher.py @@ -10,7 +10,6 @@ # limitations under the License. import unittest -from collections import OrderedDict import torch from parameterized import parameterized From fd87d9fc5e1e5dc03db61f5f0b57511bebcc08d9 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 25 May 2022 19:00:24 -0400 Subject: [PATCH 04/13] support float16 Signed-off-by: Can Zhao --- monai/apps/detection/networks/utils/ATSS_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/detection/networks/utils/ATSS_matcher.py b/monai/apps/detection/networks/utils/ATSS_matcher.py index 0aa94be16e..b2fbfe54ee 100644 --- a/monai/apps/detection/networks/utils/ATSS_matcher.py +++ b/monai/apps/detection/networks/utils/ATSS_matcher.py @@ -287,7 +287,7 @@ def compute_matches( ious_inf[index] = match_quality_matrix.view(-1)[index] ious_inf = ious_inf.view_as(match_quality_matrix) - matched_vals, matches = ious_inf.max(dim=0) + matched_vals, matches = ious_inf.to(COMPUTE_DTYPE).max(dim=0) matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}") return match_quality_matrix, matches From 56f18bafcc55e098464a4185cedcfb833896ddb3 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 25 May 2022 19:04:09 -0400 Subject: [PATCH 05/13] typo Signed-off-by: Can Zhao --- monai/apps/detection/networks/utils/ATSS_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/detection/networks/utils/ATSS_matcher.py b/monai/apps/detection/networks/utils/ATSS_matcher.py index b2fbfe54ee..57050acfce 100644 --- a/monai/apps/detection/networks/utils/ATSS_matcher.py +++ b/monai/apps/detection/networks/utils/ATSS_matcher.py @@ -61,7 +61,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -The functions in this script are tadapted from nnDetection, +The functions in this script are adapted from nnDetection, https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py which is adapted from torchvision. From 0a6a0095638c8ebfb05b623af5095c744dcb8189 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 25 May 2022 19:51:06 -0400 Subject: [PATCH 06/13] init Signed-off-by: Can Zhao --- monai/apps/detection/networks/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 monai/apps/detection/networks/__init__.py diff --git a/monai/apps/detection/networks/__init__.py b/monai/apps/detection/networks/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/detection/networks/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From a4bf11d79f76118ab457f55a55987a21edb546cc Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Thu, 26 May 2022 12:33:56 -0400 Subject: [PATCH 07/13] mv to detection/ Signed-off-by: Can Zhao --- docs/source/apps.rst | 2 +- monai/apps/detection/networks/utils/__init__.py | 10 ---------- .../detection/{networks => }/utils/ATSS_matcher.py | 0 monai/apps/detection/{networks => utils}/__init__.py | 0 tests/test_atss_box_matcher.py | 2 +- 5 files changed, 2 insertions(+), 12 deletions(-) delete mode 100644 monai/apps/detection/networks/utils/__init__.py rename monai/apps/detection/{networks => }/utils/ATSS_matcher.py (100%) rename monai/apps/detection/{networks => utils}/__init__.py (100%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 644e92a588..7d480403d8 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -137,5 +137,5 @@ Applications `Matcher` ~~~~~~~~~ -.. automodule:: monai.apps.detection.networks.utils.ATSS_matcher +.. automodule:: monai.apps.detection.utils.ATSS_matcher :members: diff --git a/monai/apps/detection/networks/utils/__init__.py b/monai/apps/detection/networks/utils/__init__.py deleted file mode 100644 index 1e97f89407..0000000000 --- a/monai/apps/detection/networks/utils/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/monai/apps/detection/networks/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py similarity index 100% rename from monai/apps/detection/networks/utils/ATSS_matcher.py rename to monai/apps/detection/utils/ATSS_matcher.py diff --git a/monai/apps/detection/networks/__init__.py b/monai/apps/detection/utils/__init__.py similarity index 100% rename from monai/apps/detection/networks/__init__.py rename to monai/apps/detection/utils/__init__.py diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py index 77092a602e..1fc332f6a9 100644 --- a/tests/test_atss_box_matcher.py +++ b/tests/test_atss_box_matcher.py @@ -14,7 +14,7 @@ import torch from parameterized import parameterized -from monai.apps.detection.networks.utils.ATSS_matcher import ATSSMatcher +from monai.apps.detection.utils.ATSS_matcher import ATSSMatcher from monai.data.box_utils import box_iou from tests.utils import assert_allclose From 93cc0a2855f9e2318a48d331b8e00873655906d7 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Fri, 27 May 2022 15:37:51 -0400 Subject: [PATCH 08/13] update docstring Signed-off-by: Can Zhao --- monai/apps/detection/utils/ATSS_matcher.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index 57050acfce..61c0ba9ee8 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -83,21 +83,22 @@ from monai.data.box_utils import COMPUTE_DTYPE, box_iou, boxes_center_distance, centers_in_boxes from monai.utils.type_conversion import convert_to_tensor -INF = 100 # not really inf but here it is sufficient +# -INF should be smaller than the lower bound of similarity_fn output. +INF = 100. # not really inf but here it is sufficient class Matcher(ABC): + """ + Base class of Matcher, which matches boxes and anchors to each other + + Args: + similarity_fn: function for similarity computation between + boxes and anchors + """ BELOW_LOW_THRESHOLD: int = -1 BETWEEN_THRESHOLDS: int = -2 - def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou): # type: ignore - """ - Matches boxes and anchors to each other - - Args: - similarity_fn: function for similarity computation between - boxes and anchors - """ + def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou): # type: ignore self.similarity_fn = similarity_fn def __call__( @@ -169,7 +170,7 @@ def __init__( num_candidates: int = 4, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou, # type: ignore center_in_gt: bool = True, - debug=False, + debug: bool = False, ): """ Compute matching based on ATSS From 6a30295dbc1d19dfa3600a440120eb0777c3000e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 May 2022 19:38:35 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/apps/detection/utils/ATSS_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index 61c0ba9ee8..c897c42aca 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -98,7 +98,7 @@ class Matcher(ABC): BELOW_LOW_THRESHOLD: int = -1 BETWEEN_THRESHOLDS: int = -2 - def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou): # type: ignore + def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou): # type: ignore self.similarity_fn = similarity_fn def __call__( From 1634241efa036992c1354d4e157c8d5a8b9388a5 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Fri, 27 May 2022 15:41:29 -0400 Subject: [PATCH 10/13] update INF Signed-off-by: Can Zhao --- monai/apps/detection/utils/ATSS_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index 61c0ba9ee8..e749b2adca 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -84,7 +84,7 @@ from monai.utils.type_conversion import convert_to_tensor # -INF should be smaller than the lower bound of similarity_fn output. -INF = 100. # not really inf but here it is sufficient +INF = float("inf") class Matcher(ABC): From d80ed0622093d110ef17a5e92da69ec62258c130 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 27 May 2022 20:04:25 +0000 Subject: [PATCH 11/13] [MONAI] code formatting Signed-off-by: monai-bot --- monai/apps/detection/utils/ATSS_matcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index 7bb457c12d..aba1dfc0a7 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -95,6 +95,7 @@ class Matcher(ABC): similarity_fn: function for similarity computation between boxes and anchors """ + BELOW_LOW_THRESHOLD: int = -1 BETWEEN_THRESHOLDS: int = -2 From d1d1f8ccc618dc036447b071880ed6d40e7342c9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 28 May 2022 08:14:47 +0100 Subject: [PATCH 12/13] format/typo fixes Signed-off-by: Wenqi Li --- monai/apps/detection/transforms/dictionary.py | 2 +- monai/apps/detection/utils/ATSS_matcher.py | 44 +++++++------------ tests/test_atss_box_matcher.py | 5 +-- 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index b802ebcfe2..aa7b0aaa75 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -706,7 +706,7 @@ class ClipBoxToImaged(MapTransform): Args: box_keys: The single key to pick box data for transformation. The box mode is assumed to be ``StandardMode``. - label_keys: Keys that represents the lables corresponding to the ``box_keys``. Multiple keys are allowed. + label_keys: Keys that represents the labels corresponding to the ``box_keys``. Multiple keys are allowed. box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` and ``label_keys`` are attached. remove_empty: whether to remove the boxes that are actually empty diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py index aba1dfc0a7..c34fe436e7 100644 --- a/monai/apps/detection/utils/ATSS_matcher.py +++ b/monai/apps/detection/utils/ATSS_matcher.py @@ -115,12 +115,10 @@ def __call__( num_anchors_per_loc: number of anchors per position Returns: - - matrix which contains the similarity from each boxes - to each anchor [N, M] + - matrix which contains the similarity from each boxes to each anchor [N, M] - vector which contains the matched box index for all anchors (if background `BELOW_LOW_THRESHOLD` is used - and if it should be ignored `BETWEEN_THRESHOLDS` is used) - [M] + and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M] Note: ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`, @@ -133,14 +131,13 @@ def __call__( match_quality_matrix = torch.tensor([]).to(anchors) matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD) return match_quality_matrix, matches - else: - # at least one ground truth - return self.compute_matches( - boxes=boxes, - anchors=anchors, - num_anchors_per_level=num_anchors_per_level, - num_anchors_per_loc=num_anchors_per_loc, - ) + # at least one ground truth + return self.compute_matches( + boxes=boxes, + anchors=anchors, + num_anchors_per_level=num_anchors_per_level, + num_anchors_per_loc=num_anchors_per_loc, + ) def compute_matches( self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int @@ -155,12 +152,10 @@ def compute_matches( num_anchors_per_loc: number of anchors per position Returns: - - matrix which contains the similarity from each boxes - to each anchor [N, M] + - matrix which contains the similarity from each boxes to each anchor [N, M] - vector which contains the matched box index for all anchors (if background `BELOW_LOW_THRESHOLD` is used - and if it should be ignored `BETWEEN_THRESHOLDS` is used) - [M] + and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M] """ raise NotImplementedError @@ -174,18 +169,16 @@ def __init__( debug: bool = False, ): """ - Compute matching based on ATSS - https://arxiv.org/abs/1912.02424 + Compute matching based on ATSS https://arxiv.org/abs/1912.02424 `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection` Args: num_candidates: number of positions to select candidates from. Smaller value will result in a higher matcher threshold and less matched candidates. - similarity_fn: function for similarity computation between - boxes and anchors + similarity_fn: function for similarity computation between boxes and anchors center_in_gt: If False (default), matched anchor center points do not need - to lie withing the ground truth box. Recommand False for small objects. + to lie withing the ground truth box. Recommend False for small objects. If True, will result in a strict matcher and less matched candidates. debug: if True, will print the matcher threshold in order to tune ``num_candidates`` and ``center_in_gt``. @@ -205,8 +198,7 @@ def compute_matches( """ Compute matches according to ATTS for a single image Adapted from - (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss - /loss.py#L180-L184) + (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184) Args: boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode`` @@ -215,12 +207,10 @@ def compute_matches( num_anchors_per_loc: number of anchors per position Returns: - Tensor: matrix which contains the similarity from each boxes - to each anchor [N, M] + Tensor: matrix which contains the similarity from each boxes to each anchor [N, M] Tensor: vector which contains the matched box index for all anchors (if background `BELOW_LOW_THRESHOLD` is used - and if it should be ignored `BETWEEN_THRESHOLDS` is used) - [M] + and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M] Note: ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`, diff --git a/tests/test_atss_box_matcher.py b/tests/test_atss_box_matcher.py index 1fc332f6a9..093641bb2f 100644 --- a/tests/test_atss_box_matcher.py +++ b/tests/test_atss_box_matcher.py @@ -18,8 +18,7 @@ from monai.data.box_utils import box_iou from tests.utils import assert_allclose -TEST_CASES = [] -TEST_CASES.append( +TEST_CASES = [ [ {"num_candidates": 2, "similarity_fn": box_iou, "center_in_gt": False}, torch.tensor([[0, 1, 2, 3, 2, 5]], dtype=torch.float16), @@ -28,7 +27,7 @@ 3, torch.tensor([0, -1, -1]), ] -) +] class TestATSS(unittest.TestCase): From 2a84778e05a64bf2f71d14949c4d46baeac1fc7d Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sat, 28 May 2022 17:48:03 +0000 Subject: [PATCH 13/13] [MONAI] code formatting Signed-off-by: monai-bot --- monai/inferers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 736eef8309..0f084c1ff5 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -155,7 +155,7 @@ def sliding_window_inference( raise RuntimeError( "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." ) from e - importance_map = convert_data_type(importance_map, torch.Tensor, device=device, dtype=compute_dtype)[0] # type: ignore + importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore # handle non-positive weights min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3) importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)