From 118cb3c077c48f7ebc7248066d18f76e07de37fb Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sun, 24 Apr 2022 14:36:45 +0800 Subject: [PATCH] enhance mean dice Signed-off-by: Yiheng Wang --- monai/metrics/meandice.py | 20 +++++++++++++++++--- tests/test_compute_meandice.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 4179420804..0783d18909 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -40,6 +40,9 @@ class DiceMetric(CumulativeIterationMetric): ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric. + ignore_empty: whether to ignore empty ground truth cases during calculation. + If `True`, NaN value will be set for empty ground truth cases. + If `False`, 1 will be set if the predictions of empty ground truth cases are also empty. """ @@ -48,11 +51,13 @@ def __init__( include_background: bool = True, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False, + ignore_empty: bool = True, ) -> None: super().__init__() self.include_background = include_background self.reduction = reduction self.get_not_nans = get_not_nans + self.ignore_empty = ignore_empty def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore """ @@ -77,7 +82,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute dice (BxC) for each channel for each batch - return compute_meandice(y_pred=y_pred, y=y, include_background=self.include_background) + return compute_meandice( + y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty + ) def aggregate(self): """ @@ -93,7 +100,9 @@ def aggregate(self): return (f, not_nans) if self.get_not_nans else f -def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor: +def compute_meandice( + y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True +) -> torch.Tensor: """Computes Dice score metric from full size Tensor and collects average. Args: @@ -104,6 +113,9 @@ def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background: The values should be binarized. include_background: whether to skip Dice computation on the first channel of the predicted output. Defaults to True. + ignore_empty: whether to ignore empty ground truth cases during calculation. + If `True`, NaN value will be set for empty ground truth cases. + If `False`, 1 will be set if the predictions of empty ground truth cases are also empty. Returns: Dice scores per batch and per class, (shape [batch_size, num_classes]). @@ -131,4 +143,6 @@ def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background: y_pred_o = torch.sum(y_pred, dim=reduce_axis) denominator = y_o + y_pred_o - return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device)) + if ignore_empty is True: + return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device)) + return torch.where(denominator > 0, (2.0 * intersection) / denominator, torch.tensor(1.0, device=y_o.device)) diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index ad66ed672a..a6b336f2d9 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -172,9 +172,19 @@ [[1.0000, 1.0000], [1.0000, 1.0000]], ] +TEST_CASE_11 = [ + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3)), "ignore_empty": False}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] + +TEST_CASE_12 = [ + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3)), "ignore_empty": False}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] + class TestComputeMeanDice(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) def test_value(self, input_data, expected_value): result = compute_meandice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)