Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class DiceMetric(CumulativeIterationMetric):
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.

"""

Expand All @@ -48,11 +51,13 @@ def __init__(
include_background: bool = True,
reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
get_not_nans: bool = False,
ignore_empty: bool = True,
) -> None:
super().__init__()
self.include_background = include_background
self.reduction = reduction
self.get_not_nans = get_not_nans
self.ignore_empty = ignore_empty

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore
"""
Expand All @@ -77,7 +82,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
if dims < 3:
raise ValueError("y_pred should have at least three dimensions.")
# compute dice (BxC) for each channel for each batch
return compute_meandice(y_pred=y_pred, y=y, include_background=self.include_background)
return compute_meandice(
y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty
)

def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
"""
Expand All @@ -98,7 +105,9 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # typ
return (f, not_nans) if self.get_not_nans else f


def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor:
def compute_meandice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
) -> torch.Tensor:
"""Computes Dice score metric from full size Tensor and collects average.

Args:
Expand All @@ -109,6 +118,9 @@ def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background:
The values should be binarized.
include_background: whether to skip Dice computation on the first channel of
the predicted output. Defaults to True.
ignore_empty: whether to ignore empty ground truth cases during calculation.
If `True`, NaN value will be set for empty ground truth cases.
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
Comment thread
wyli marked this conversation as resolved.

Returns:
Dice scores per batch and per class, (shape [batch_size, num_classes]).
Expand Down Expand Up @@ -136,4 +148,6 @@ def compute_meandice(y_pred: torch.Tensor, y: torch.Tensor, include_background:
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
denominator = y_o + y_pred_o

return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
if ignore_empty is True:
return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
return torch.where(denominator > 0, (2.0 * intersection) / denominator, torch.tensor(1.0, device=y_o.device))
12 changes: 11 additions & 1 deletion tests/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,19 @@
[[1.0000, 1.0000], [1.0000, 1.0000]],
]

TEST_CASE_11 = [
{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3)), "ignore_empty": False},
[[1.0000, 1.0000], [1.0000, 1.0000]],
]

TEST_CASE_12 = [
{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3)), "ignore_empty": False},
[[0.0000, 0.0000], [0.0000, 0.0000]],
]


class TestComputeMeanDice(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
def test_value(self, input_data, expected_value):
result = compute_meandice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
Expand Down