Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Metrics
-----------
.. autofunction:: compute_meandice

.. autoclass:: DiceMetric
:members:

`Area under the ROC curve`
--------------------------
.. autofunction:: compute_roc_auc
9 changes: 4 additions & 5 deletions examples/segmentation_3d/unet_evaluation_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.nets import UNet
from monai.data import create_test_image_3d, NiftiSaver, NiftiDataset
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.metrics import DiceMetric


def main():
Expand All @@ -52,6 +52,7 @@ def main():
val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
# sliding window inference for one image at every iteration
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

device = torch.device("cuda:0")
model = UNet(
Expand All @@ -75,11 +76,9 @@ def main():
roi_size = (96, 96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
value = compute_meandice(
y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True
)
value = dice_metric(y_pred=val_outputs, y=val_labels)
metric_count += len(value)
metric_sum += value.sum().item()
metric_sum += value.item() * len(value)
val_outputs = (val_outputs.sigmoid() >= 0.5).float()
saver.save_batch(val_outputs, val_data[2])
metric = metric_sum / metric_count
Expand Down
13 changes: 5 additions & 8 deletions examples/segmentation_3d/unet_evaluation_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import monai
from monai.data import list_data_collate, create_test_image_3d, NiftiSaver
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import Compose, LoadNiftid, AsChannelFirstd, ScaleIntensityd, ToTensord
from monai.engines import get_devices_spec
Expand Down Expand Up @@ -59,9 +59,8 @@ def main():
)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
# sliding window inference need to input 1 image in every iteration
val_loader = DataLoader(
val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

# try to use all the available GPUs
devices = get_devices_spec(None)
Expand Down Expand Up @@ -91,11 +90,9 @@ def main():
roi_size = (96, 96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
value = compute_meandice(
y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True
)
value = dice_metric(y_pred=val_outputs, y=val_labels)
metric_count += len(value)
metric_sum += value.sum().item()
metric_sum += value.item() * len(value)
val_outputs = (val_outputs.sigmoid() >= 0.5).float()
saver.save_batch(val_outputs, val_data["img_meta_dict"])
metric = metric_sum / metric_count
Expand Down
9 changes: 4 additions & 5 deletions examples/segmentation_3d/unet_training_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.data import NiftiDataset, create_test_image_3d
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, AddChannel, ScaleIntensity, RandSpatialCrop, RandRotate90, ToTensor
from monai.metrics import compute_meandice
from monai.metrics import DiceMetric
from monai.visualize.img2tensorboard import plot_2d_or_3d_image


Expand Down Expand Up @@ -81,6 +81,7 @@ def main():
# create a validation data loader
val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

# create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
Expand Down Expand Up @@ -137,11 +138,9 @@ def main():
roi_size = (96, 96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
value = compute_meandice(
y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True
)
value = dice_metric(y_pred=val_outputs, y=val_labels)
metric_count += len(value)
metric_sum += value.sum().item()
metric_sum += value.item() * len(value)
metric = metric_sum / metric_count
metric_values.append(metric)
if metric > best_metric:
Expand Down
17 changes: 6 additions & 11 deletions examples/segmentation_3d/unet_training_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from monai.data import create_test_image_3d, list_data_collate
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.metrics import DiceMetric
from monai.visualize import plot_2d_or_3d_image


Expand Down Expand Up @@ -83,9 +83,7 @@ def main():
# define dataset, data loader
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
check_loader = DataLoader(
check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate)
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)

Expand All @@ -102,9 +100,8 @@ def main():
)
# create a validation data loader
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

# create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
Expand Down Expand Up @@ -161,11 +158,9 @@ def main():
roi_size = (96, 96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
value = compute_meandice(
y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True
)
value = dice_metric(y_pred=val_outputs, y=val_labels)
metric_count += len(value)
metric_sum += value.sum().item()
metric_sum += value.item() * len(value)
metric = metric_sum / metric_count
metric_values.append(metric)
if metric > best_metric:
Expand Down
36 changes: 13 additions & 23 deletions monai/handlers/mean_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from monai.metrics import compute_meandice
from monai.metrics import DiceMetric
from monai.utils import exact_version, optional_import

NotComputableError, _ = optional_import("ignite.exceptions", "0.3.0", exact_version, "NotComputableError")
Expand Down Expand Up @@ -55,12 +55,14 @@ def __init__(
:py:meth:`monai.metrics.meandice.compute_meandice`
"""
super().__init__(output_transform, device=device)
self.include_background = include_background
self.to_onehot_y = to_onehot_y
self.mutually_exclusive = mutually_exclusive
self.sigmoid = sigmoid
self.logit_thresh = logit_thresh

self.dice = DiceMetric(
include_background=include_background,
to_onehot_y=to_onehot_y,
mutually_exclusive=mutually_exclusive,
sigmoid=sigmoid,
logit_thresh=logit_thresh,
reduction="mean",
)
self._sum = 0
self._num_examples = 0

Expand All @@ -74,24 +76,12 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]):
if not len(output) == 2:
raise ValueError("MeanDice metric can only support y_pred and y.")
y_pred, y = output
scores = compute_meandice(
y_pred,
y,
self.include_background,
self.to_onehot_y,
self.mutually_exclusive,
self.sigmoid,
self.logit_thresh,
)
score = self.dice(y_pred, y)
not_nans = self.dice.not_nans.item()

# add all items in current batch
for batch in scores:
not_nan = ~torch.isnan(batch)
if not_nan.sum() == 0:
continue
class_avg = batch[not_nan].mean().item()
self._sum += class_avg
self._num_examples += 1
self._sum += score.item() * not_nans
self._num_examples += not_nans

@sync_all_reduce("_sum", "_num_examples")
def compute(self):
Expand Down
55 changes: 37 additions & 18 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,26 @@
class DiceMetric:
"""
Compute average Dice score between two tensors. It can support both multi-classes and multi-labels tasks.
Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).
Axis N of `input` is expected to have logit predictions for each class rather than being image channels,
while the same axis of `target` can be 1 or N (one-hot format). The `smooth` parameter is a value added to the
intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0,
this value should be small. The `include_background` class attribute can be set to False for an instance of
DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background.
If the non-background segmentations are small compared to the total image size they can get overwhelmed by
the signal from the background so excluding it in such cases helps convergence.
Input logits `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]).
Axis N of `y_pred` is expected to have logit predictions for each class rather than being image channels,
while the same axis of `y` can be 1 or N (one-hot format). The `include_background` class attribute can be
set to False for an instance of DiceMetric to exclude the first category (channel index 0) which is by
convention assumed to be background. If the non-background segmentations are small compared to the total
image size they can get overwhelmed by the signal from the background so excluding it in such cases helps
convergence.

Args:
include_background: whether to include Dice computation on the first channel of
the predicted output. Defaults to True.
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using
a combination of argmax and to_onehot. Defaults to False.
sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False.
logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`)
`y_pred` into a binary matrix. Defaults to 0.5.
reduction: define the mode to reduce computation result of 1 batch data.
available modes: `none`, `mean`, `sum`, `mean_batch`, `sum_batch`, `mean_channel`, `sum_channel`.
default is `mean`, average on channel dim then on batch dim.

"""

Expand All @@ -40,7 +52,7 @@ def __init__(
):
super().__init__()

if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]:
if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"]:
raise ValueError(f"reduction={reduction} is invalid. Valid options are: none, mean or sum.")

self.include_background = include_background
Expand All @@ -52,12 +64,12 @@ def __init__(

self.not_nans = None # keep track for valid elements in the batch

def __call__(self, input: torch.Tensor, target: torch.Tensor):
def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):

# compute dice (BxC) for each channel for each batch
f = compute_meandice(
y_pred=input,
y=target,
y_pred=y_pred,
y=y,
include_background=self.include_background,
to_onehot_y=self.to_onehot_y,
mutually_exclusive=self.mutually_exclusive,
Expand All @@ -75,13 +87,13 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor):
t_zero = torch.zeros(1, device=f.device, dtype=torch.float)

if self.reduction == "mean":
# 2 steps, first, mean by batch (accounting for nans), then by channel
# 2 steps, first, mean by channel (accounting for nans), then by batch

not_nans = not_nans.sum(dim=0)
f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average
not_nans = not_nans.sum(dim=1)
f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average

not_nans = not_nans.sum()
f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # channel average
f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # batch average

elif self.reduction == "sum":
not_nans = not_nans.sum()
Expand All @@ -92,12 +104,19 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor):
elif self.reduction == "sum_batch":
not_nans = not_nans.sum(dim=0)
f = f.sum(dim=0) # the batch sum
elif self.reduction == "mean_channel":
not_nans = not_nans.sum(dim=1)
f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average
elif self.reduction == "sum_channel":
not_nans = not_nans.sum(dim=1)
f = f.sum(dim=1) # the channel sum
elif self.reduction == "none":
pass
else:
raise ValueError(f"reduction={self.reduction} is invalid.")

self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid
# save not_nans since we may need it later to know how many elements were valid
self.not_nans = not_nans

return f

Expand Down Expand Up @@ -186,4 +205,4 @@ def compute_meandice(
denominator = y_o + y_pred_o

f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
return f # returns array of Dice shape: [Batch, n_classes]
return f # returns array of Dice shape: [batch, n_classes]
16 changes: 8 additions & 8 deletions tests/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,41 +63,41 @@
TEST_CASE_4 = [
{"include_background": True, "to_onehot_y": True, "reduction": "mean_batch"},
{
"input": torch.tensor(
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[1.0, 1.0], [2.0, 0.0]]]]),
"y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[1.0, 1.0], [2.0, 0.0]]]]),
},
[0.6786, 0.4000, 0.6667],
]

TEST_CASE_5 = [
{"include_background": True, "to_onehot_y": True, "reduction": "sum_batch"},
{
"input": torch.tensor(
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]),
"y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]),
},
[1.7143, 0.0000, 0.0000],
]

TEST_CASE_6 = [
{"to_onehot_y": True, "include_background": False, "reduction": "sum_batch"},
{
"input": torch.tensor(
"y_pred": torch.tensor(
[
[[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
[[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]),
"y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]),
},
[0.0000, 0.0000],
]
Expand Down Expand Up @@ -127,8 +127,8 @@ def test_value_class(self, input_data, expected_value):

# same test as for compute_meandice
vals = dict()
vals["input"] = input_data.pop("y_pred")
vals["target"] = input_data.pop("y")
vals["y_pred"] = input_data.pop("y_pred")
vals["y"] = input_data.pop("y")
dice_metric = DiceMetric(**input_data, reduction="none")
result = dice_metric(**vals)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
Expand Down
Loading