diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 109240f313..69b21195c2 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -10,6 +10,9 @@ Metrics ----------- .. autofunction:: compute_meandice +.. autoclass:: DiceMetric + :members: + `Area under the ROC curve` -------------------------- .. autofunction:: compute_roc_auc diff --git a/examples/segmentation_3d/unet_evaluation_array.py b/examples/segmentation_3d/unet_evaluation_array.py index ba4c855474..2fcb14ff57 100644 --- a/examples/segmentation_3d/unet_evaluation_array.py +++ b/examples/segmentation_3d/unet_evaluation_array.py @@ -25,7 +25,7 @@ from monai.networks.nets import UNet from monai.data import create_test_image_3d, NiftiSaver, NiftiDataset from monai.inferers import sliding_window_inference -from monai.metrics import compute_meandice +from monai.metrics import DiceMetric def main(): @@ -52,6 +52,7 @@ def main(): val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) # sliding window inference for one image at every iteration val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") device = torch.device("cuda:0") model = UNet( @@ -75,11 +76,9 @@ def main(): roi_size = (96, 96, 96) sw_batch_size = 4 val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice( - y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True - ) + value = dice_metric(y_pred=val_outputs, y=val_labels) metric_count += len(value) - metric_sum += value.sum().item() + metric_sum += value.item() * len(value) val_outputs = (val_outputs.sigmoid() >= 0.5).float() saver.save_batch(val_outputs, val_data[2]) metric = metric_sum / metric_count diff --git a/examples/segmentation_3d/unet_evaluation_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py index b66394d3f5..353416894e 100644 --- a/examples/segmentation_3d/unet_evaluation_dict.py +++ b/examples/segmentation_3d/unet_evaluation_dict.py @@ -23,7 +23,7 @@ import monai from monai.data import list_data_collate, create_test_image_3d, NiftiSaver from monai.inferers import sliding_window_inference -from monai.metrics import compute_meandice +from monai.metrics import DiceMetric from monai.networks.nets import UNet from monai.transforms import Compose, LoadNiftid, AsChannelFirstd, ScaleIntensityd, ToTensord from monai.engines import get_devices_spec @@ -59,9 +59,8 @@ def main(): ) val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # sliding window inference need to input 1 image in every iteration - val_loader = DataLoader( - val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") # try to use all the available GPUs devices = get_devices_spec(None) @@ -91,11 +90,9 @@ def main(): roi_size = (96, 96, 96) sw_batch_size = 4 val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice( - y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True - ) + value = dice_metric(y_pred=val_outputs, y=val_labels) metric_count += len(value) - metric_sum += value.sum().item() + metric_sum += value.item() * len(value) val_outputs = (val_outputs.sigmoid() >= 0.5).float() saver.save_batch(val_outputs, val_data["img_meta_dict"]) metric = metric_sum / metric_count diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py index 8079eb03c6..b42f8d1140 100644 --- a/examples/segmentation_3d/unet_training_array.py +++ b/examples/segmentation_3d/unet_training_array.py @@ -25,7 +25,7 @@ from monai.data import NiftiDataset, create_test_image_3d from monai.inferers import sliding_window_inference from monai.transforms import Compose, AddChannel, ScaleIntensity, RandSpatialCrop, RandRotate90, ToTensor -from monai.metrics import compute_meandice +from monai.metrics import DiceMetric from monai.visualize.img2tensorboard import plot_2d_or_3d_image @@ -81,6 +81,7 @@ def main(): # create a validation data loader val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available()) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") # create UNet, DiceLoss and Adam optimizer device = torch.device("cuda:0") @@ -137,11 +138,9 @@ def main(): roi_size = (96, 96, 96) sw_batch_size = 4 val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice( - y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True - ) + value = dice_metric(y_pred=val_outputs, y=val_labels) metric_count += len(value) - metric_sum += value.sum().item() + metric_sum += value.item() * len(value) metric = metric_sum / metric_count metric_values.append(metric) if metric > best_metric: diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py index c2ea913e76..379b63181f 100644 --- a/examples/segmentation_3d/unet_training_dict.py +++ b/examples/segmentation_3d/unet_training_dict.py @@ -33,7 +33,7 @@ ) from monai.data import create_test_image_3d, list_data_collate from monai.inferers import sliding_window_inference -from monai.metrics import compute_meandice +from monai.metrics import DiceMetric from monai.visualize import plot_2d_or_3d_image @@ -83,9 +83,7 @@ def main(): # define dataset, data loader check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - check_loader = DataLoader( - check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) + check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate) check_data = monai.utils.misc.first(check_loader) print(check_data["img"].shape, check_data["seg"].shape) @@ -102,9 +100,8 @@ def main(): ) # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - val_loader = DataLoader( - val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") # create UNet, DiceLoss and Adam optimizer device = torch.device("cuda:0") @@ -161,11 +158,9 @@ def main(): roi_size = (96, 96, 96) sw_batch_size = 4 val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice( - y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True - ) + value = dice_metric(y_pred=val_outputs, y=val_labels) metric_count += len(value) - metric_sum += value.sum().item() + metric_sum += value.item() * len(value) metric = metric_sum / metric_count metric_values.append(metric) if metric > best_metric: diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 701f3a76c7..50b26d48c5 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -13,7 +13,7 @@ import torch -from monai.metrics import compute_meandice +from monai.metrics import DiceMetric from monai.utils import exact_version, optional_import NotComputableError, _ = optional_import("ignite.exceptions", "0.3.0", exact_version, "NotComputableError") @@ -55,12 +55,14 @@ def __init__( :py:meth:`monai.metrics.meandice.compute_meandice` """ super().__init__(output_transform, device=device) - self.include_background = include_background - self.to_onehot_y = to_onehot_y - self.mutually_exclusive = mutually_exclusive - self.sigmoid = sigmoid - self.logit_thresh = logit_thresh - + self.dice = DiceMetric( + include_background=include_background, + to_onehot_y=to_onehot_y, + mutually_exclusive=mutually_exclusive, + sigmoid=sigmoid, + logit_thresh=logit_thresh, + reduction="mean", + ) self._sum = 0 self._num_examples = 0 @@ -74,24 +76,12 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]): if not len(output) == 2: raise ValueError("MeanDice metric can only support y_pred and y.") y_pred, y = output - scores = compute_meandice( - y_pred, - y, - self.include_background, - self.to_onehot_y, - self.mutually_exclusive, - self.sigmoid, - self.logit_thresh, - ) + score = self.dice(y_pred, y) + not_nans = self.dice.not_nans.item() # add all items in current batch - for batch in scores: - not_nan = ~torch.isnan(batch) - if not_nan.sum() == 0: - continue - class_avg = batch[not_nan].mean().item() - self._sum += class_avg - self._num_examples += 1 + self._sum += score.item() * not_nans + self._num_examples += not_nans @sync_all_reduce("_sum", "_num_examples") def compute(self): diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index f02c874d93..359cbbfbad 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -18,14 +18,26 @@ class DiceMetric: """ Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. - Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Axis N of `input` is expected to have logit predictions for each class rather than being image channels, - while the same axis of `target` can be 1 or N (one-hot format). The `smooth` parameter is a value added to the - intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0, - this value should be small. The `include_background` class attribute can be set to False for an instance of - DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. - If the non-background segmentations are small compared to the total image size they can get overwhelmed by - the signal from the background so excluding it in such cases helps convergence. + Input logits `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + Axis N of `y_preds` is expected to have logit predictions for each class rather than being image channels, + while the same axis of `y` can be 1 or N (one-hot format). The `include_background` class attribute can be + set to False for an instance of DiceLoss to exclude the first category (channel index 0) which is by + convention assumed to be background. If the non-background segmentations are small compared to the total + image size they can get overwhelmed by the signal from the background so excluding it in such cases helps + convergence. + + Args: + include_background: whether to skip Dice computation on the first channel of + the predicted output. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. + logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + reduction: define the mode to reduce computation result of 1 batch data. + available modes: `none`, `mean`, `sum`, `mean_batch`, `sum_batch`, `mean_batch`, `sum_batch`. + default is `mean`, average on channel dim then on batch dim. """ @@ -40,7 +52,7 @@ def __init__( ): super().__init__() - if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: + if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel"]: raise ValueError(f"reduction={reduction} is invalid. Valid options are: none, mean or sum.") self.include_background = include_background @@ -52,12 +64,12 @@ def __init__( self.not_nans = None # keep track for valid elements in the batch - def __call__(self, input: torch.Tensor, target: torch.Tensor): + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): # compute dice (BxC) for each channel for each batch f = compute_meandice( - y_pred=input, - y=target, + y_pred=y_pred, + y=y, include_background=self.include_background, to_onehot_y=self.to_onehot_y, mutually_exclusive=self.mutually_exclusive, @@ -75,13 +87,13 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor): t_zero = torch.zeros(1, device=f.device, dtype=torch.float) if self.reduction == "mean": - # 2 steps, first, mean by batch (accounting for nans), then by channel + # 2 steps, first, mean by channel (accounting for nans), then by batch - not_nans = not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + not_nans = not_nans.sum(dim=1) + f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average not_nans = not_nans.sum() - f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # channel average + f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # batch average elif self.reduction == "sum": not_nans = not_nans.sum() @@ -92,12 +104,19 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor): elif self.reduction == "sum_batch": not_nans = not_nans.sum(dim=0) f = f.sum(dim=0) # the batch sum + elif self.reduction == "mean_channel": + not_nans = not_nans.sum(dim=1) + f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average + elif self.reduction == "sum_channel": + not_nans = not_nans.sum(dim=1) + f = f.sum(dim=1) # the channel sum elif self.reduction == "none": pass else: raise ValueError(f"reduction={self.reduction} is invalid.") - self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid + # save not_nans since we may need it later to know how many elements were valid + self.not_nans = not_nans return f @@ -186,4 +205,4 @@ def compute_meandice( denominator = y_o + y_pred_o f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device)) - return f # returns array of Dice shape: [Batch, n_classes] + return f # returns array of Dice shape: [batch, n_classes] diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 9f21eb5d98..2831cad8b1 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -63,13 +63,13 @@ TEST_CASE_4 = [ {"include_background": True, "to_onehot_y": True, "reduction": "mean_batch"}, { - "input": torch.tensor( + "y_pred": torch.tensor( [ [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], ] ), - "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[1.0, 1.0], [2.0, 0.0]]]]), + "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[1.0, 1.0], [2.0, 0.0]]]]), }, [0.6786, 0.4000, 0.6667], ] @@ -77,13 +77,13 @@ TEST_CASE_5 = [ {"include_background": True, "to_onehot_y": True, "reduction": "sum_batch"}, { - "input": torch.tensor( + "y_pred": torch.tensor( [ [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], ] ), - "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]), + "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]), }, [1.7143, 0.0000, 0.0000], ] @@ -91,13 +91,13 @@ TEST_CASE_6 = [ {"to_onehot_y": True, "include_background": False, "reduction": "sum_batch"}, { - "input": torch.tensor( + "y_pred": torch.tensor( [ [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], ] ), - "target": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]), + "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]], [[[0.0, 0.0], [0.0, 0.0]]]]), }, [0.0000, 0.0000], ] @@ -127,8 +127,8 @@ def test_value_class(self, input_data, expected_value): # same test as for compute_meandice vals = dict() - vals["input"] = input_data.pop("y_pred") - vals["target"] = input_data.pop("y") + vals["y_pred"] = input_data.pop("y_pred") + vals["y"] = input_data.pop("y") dice_metric = DiceMetric(**input_data, reduction="none") result = dice_metric(**vals) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index fdbb49deed..edd8553635 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -24,7 +24,7 @@ import monai from monai.data import create_test_image_3d, NiftiSaver, list_data_collate from monai.inferers import sliding_window_inference -from monai.metrics import compute_meandice +from monai.metrics import DiceMetric from monai.networks.nets import UNet from monai.transforms import ( Compose, @@ -79,19 +79,11 @@ def run_training_test(root_dir, device=torch.device("cuda:0"), cachedataset=Fals else: train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - train_loader = DataLoader( - train_ds, - batch_size=2, - shuffle=True, - num_workers=4, - collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available(), - ) + train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate) # create a validation data loader val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - val_loader = DataLoader( - val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") # create UNet, DiceLoss and Adam optimizer model = monai.networks.nets.UNet( @@ -146,11 +138,10 @@ def run_training_test(root_dir, device=torch.device("cuda:0"), cachedataset=Fals val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice( - y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True - ) - metric_count += len(value) - metric_sum += value.sum().item() + value = dice_metric(y_pred=val_outputs, y=val_labels) + not_nans = dice_metric.not_nans.item() + metric_count += not_nans + metric_sum += value.item() * not_nans metric = metric_sum / metric_count metric_values.append(metric) if metric > best_metric: @@ -189,9 +180,8 @@ def run_inference_test(root_dir, device=torch.device("cuda:0")): ) val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # sliding window inferene need to input 1 image in every iteration - val_loader = DataLoader( - val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") model = UNet( dimensions=3, @@ -214,11 +204,10 @@ def run_inference_test(root_dir, device=torch.device("cuda:0")): # define sliding window size and batch size for windows inference sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = compute_meandice( - y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, sigmoid=True - ) - metric_count += len(value) - metric_sum += value.sum().item() + value = dice_metric(y_pred=val_outputs, y=val_labels) + not_nans = dice_metric.not_nans.item() + metric_count += not_nans + metric_sum += value.item() * not_nans val_outputs = (val_outputs.sigmoid() >= 0.5).float() saver.save_batch(val_outputs, val_data["img_meta_dict"]) metric = metric_sum / metric_count