diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index ed66bfab7b..23432b32bb 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .meandice import compute_meandice, DiceMetric -from .rocauc import compute_roc_auc +from monai.metrics.functional import * +from monai.metrics.metric import Metric, CumulativeMetric +from monai.metrics.dice import Dice, CumulativeDice +from monai.metrics.rocauc import ROCOUC, CumulativeROCAUC diff --git a/monai/metrics/dice.py b/monai/metrics/dice.py new file mode 100644 index 0000000000..4d853b4971 --- /dev/null +++ b/monai/metrics/dice.py @@ -0,0 +1,183 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.metrics.functional.meandice import compute_meandice, _convert_predictions +from monai.metrics.metric import Metric, CumulativeMetric + + +class Dice(Metric): + """ + Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. + Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). + Axis N of `input` is expected to have logit predictions for each class rather than being image channels, + while the same axis of `target` can be 1 or N (one-hot format). 
Dice values that are NaN (ground truth `y` missing for a channel) are + set to 0 and left out of the `mean` and `mean_batch` averages; the valid elements are tracked in + `not_nans` after each call. The `include_background` class attribute can be set to False for an instance of + Dice to exclude the first category (channel index 0) which is by convention assumed to be background. + If the non-background segmentations are small compared to the total image size they can get overwhelmed by + the signal from the background so excluding it in such cases helps convergence. + """ + + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + mutually_exclusive: bool = False, + sigmoid: bool = False, + logit_thresh: float = 0.5, + reduction: str = "mean", + ): + """ + Args: + include_background: whether to include the first channel of the predicted output + in the Dice computation. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. + logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + """ + super().__init__() + + if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: + raise ValueError(f"reduction={reduction} is invalid. 
Valid options are: none, mean or sum.") + + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.mutually_exclusive = mutually_exclusive + self.sigmoid = sigmoid + self.logit_thresh = logit_thresh + self.reduction = reduction + + self.not_nans = None # keep track for valid elements in the batch + + def __call__(self, input: torch.Tensor, target: torch.Tensor) -> float: + + # compute dice (BxC) for each channel for each batch + f = compute_meandice( + y_pred=input, + y=target, + include_background=self.include_background, + to_onehot_y=self.to_onehot_y, + mutually_exclusive=self.mutually_exclusive, + sigmoid=self.sigmoid, + logit_thresh=self.logit_thresh, + ) + + # some dice elements might be Nan (if ground truth y was missing (zeros)) + # we need to account for it + + nans = torch.isnan(f) + not_nans = (~nans).float() + f[nans] = 0 + + t_zero = torch.zeros(1, device=f.device, dtype=torch.float) + + if self.reduction == "mean": + # 2 steps, first, mean by batch (accounting for nans), then by channel + + not_nans = not_nans.sum(dim=0) + f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + + not_nans = not_nans.sum() + f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # channel average + + elif self.reduction == "sum": + not_nans = not_nans.sum() + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == "mean_batch": + not_nans = not_nans.sum(dim=0) + f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + elif self.reduction == "sum_batch": + not_nans = not_nans.sum(dim=0) + f = f.sum(dim=0) # the batch sum + elif self.reduction == "none": + pass + else: + raise ValueError(f"reduction={self.reduction} is invalid.") + + self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid + + return f + + +class CumulativeDice(CumulativeMetric): + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool 
= False, + mutually_exclusive: bool = False, + sigmoid: bool = False, + logit_thresh: float = 0.5, + reduction: str = "mean", + ): + """ + Args: + include_background: whether to skip Dice computation on the first channel of + the predicted output. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. + logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + """ + super().__init__() + + if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: + raise ValueError(f"reduction={reduction} is invalid. Valid options are: none, mean or sum.") + + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.mutually_exclusive = mutually_exclusive + self.sigmoid = sigmoid + self.logit_thresh = logit_thresh + self.reduction = reduction + + self._intersection: torch.Tensor = 0 + self._sum_y: torch.Tensor = 0 + self._sum_y_pred: torch.Tensor = 0 + + def __call__(self): + denominator = self._sum_y + self._sum_y_pred + f = torch.where( + self._sum_y > 0, + (2.0 * self._intersection) / denominator, + torch.tensor(float("nan"), device=self._sum_y.device), + ) + return f # returns array of Dice shape: [Batch, n_classes] + + def add_sample(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: + y_pred, y = _convert_predictions( + y_pred, + y, + include_background=self.include_background, + to_onehot_y=self.to_onehot_y, + mutually_exclusive=self.mutually_exclusive, + sigmoid=self.sigmoid, + logit_thresh=self.logit_thresh, + ) + n_len = len(y_pred.shape) + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis = list(range(2, n_len)) + + self._intersection = torch.sum(y 
* y_pred, dim=reduce_axis) + self._intersection + self._sum_y = torch.sum(y, dim=reduce_axis) + self._sum_y + self._sum_y_pred = torch.sum(y_pred, dim=reduce_axis) + self._sum_y_pred + + def reset(self, *args, **kwargs) -> None: + self._intersection: torch.Tensor = 0 + self._sum_y: torch.Tensor = 0 + self._sum_y_pred: torch.Tensor = 0 diff --git a/monai/metrics/functional/__init__.py b/monai/metrics/functional/__init__.py new file mode 100644 index 0000000000..0e72dfe5ab --- /dev/null +++ b/monai/metrics/functional/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .meandice import compute_meandice +from .rocauc import compute_roc_auc diff --git a/monai/metrics/meandice.py b/monai/metrics/functional/meandice.py similarity index 52% rename from monai/metrics/meandice.py rename to monai/metrics/functional/meandice.py index f02c874d93..cd4463d76e 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/functional/meandice.py @@ -15,94 +15,7 @@ from monai.networks.utils import one_hot -class DiceMetric: - """ - Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. - Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Axis N of `input` is expected to have logit predictions for each class rather than being image channels, - while the same axis of `target` can be 1 or N (one-hot format). 
The `smooth` parameter is a value added to the - intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0, - this value should be small. The `include_background` class attribute can be set to False for an instance of - DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. - If the non-background segmentations are small compared to the total image size they can get overwhelmed by - the signal from the background so excluding it in such cases helps convergence. - - """ - - def __init__( - self, - include_background: bool = True, - to_onehot_y: bool = False, - mutually_exclusive: bool = False, - sigmoid: bool = False, - logit_thresh: float = 0.5, - reduction: str = "mean", - ): - super().__init__() - - if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: - raise ValueError(f"reduction={reduction} is invalid. Valid options are: none, mean or sum.") - - self.include_background = include_background - self.to_onehot_y = to_onehot_y - self.mutually_exclusive = mutually_exclusive - self.sigmoid = sigmoid - self.logit_thresh = logit_thresh - self.reduction = reduction - - self.not_nans = None # keep track for valid elements in the batch - - def __call__(self, input: torch.Tensor, target: torch.Tensor): - - # compute dice (BxC) for each channel for each batch - f = compute_meandice( - y_pred=input, - y=target, - include_background=self.include_background, - to_onehot_y=self.to_onehot_y, - mutually_exclusive=self.mutually_exclusive, - sigmoid=self.sigmoid, - logit_thresh=self.logit_thresh, - ) - - # some dice elements might be Nan (if ground truth y was missing (zeros)) - # we need to account for it - - nans = torch.isnan(f) - not_nans = (~nans).float() - f[nans] = 0 - - t_zero = torch.zeros(1, device=f.device, dtype=torch.float) - - if self.reduction == "mean": - # 2 steps, first, mean by batch (accounting for nans), then by channel - - not_nans = 
not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average - - not_nans = not_nans.sum() - f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # channel average - - elif self.reduction == "sum": - not_nans = not_nans.sum() - f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction == "mean_batch": - not_nans = not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average - elif self.reduction == "sum_batch": - not_nans = not_nans.sum(dim=0) - f = f.sum(dim=0) # the batch sum - elif self.reduction == "none": - pass - else: - raise ValueError(f"reduction={self.reduction} is invalid.") - - self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid - - return f - - -def compute_meandice( +def _convert_predictions( y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, @@ -111,35 +24,7 @@ def compute_meandice( sigmoid: bool = False, logit_thresh: float = 0.5, ): - """Computes Dice score metric from full size Tensor and collects average. - - Args: - y_pred (torch.Tensor): input data to compute, typical segmentation model output. - it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. - y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch. - example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32]. - alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly. - include_background: whether to skip Dice computation on the first channel of - the predicted output. Defaults to True. - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using - a combination of argmax and to_onehot. Defaults to False. - sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. 
- logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) - `y_pred` into a binary matrix. Defaults to 0.5. - - Returns: - Dice scores per batch and per class, (shape [batch_size, n_classes]). - - Note: - This method provides two options to convert `y_pred` into a binary matrix - (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``, - (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh`` - (optionally with a ``sigmoid`` function before thresholding). - - """ n_classes = y_pred.shape[1] - n_len = len(y_pred.shape) if sigmoid: y_pred = y_pred.float().sigmoid() @@ -176,12 +61,60 @@ def compute_meandice( ) y = y.float() y_pred = y_pred.float() + return y_pred, y + + +def compute_meandice( + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + to_onehot_y: bool = False, + mutually_exclusive: bool = False, + sigmoid: bool = False, + logit_thresh: float = 0.5, +): + """Computes Dice score metric from full size Tensor and collects average. + + Args: + y_pred : input data to compute, typical segmentation model output. + it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. + y: ground truth to compute mean dice metric, the first dim is batch. + example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32]. + alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly. + include_background: whether to skip Dice computation on the first channel of + the predicted output. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. 
+ logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + + Returns: + Dice scores per batch and per class, (shape [batch_size, n_classes]). + + Note: + This method provides two options to convert `y_pred` into a binary matrix + (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``, + (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh`` + (optionally with a ``sigmoid`` function before thresholding). + """ + y_pred, y = _convert_predictions( + y_pred, + y, + include_background=include_background, + to_onehot_y=to_onehot_y, + mutually_exclusive=mutually_exclusive, + sigmoid=sigmoid, + logit_thresh=logit_thresh, + ) + n_len = len(y_pred.shape) # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, n_len)) intersection = torch.sum(y * y_pred, dim=reduce_axis) - y_o = torch.sum(y, reduce_axis) + y_o = torch.sum(y, dim=reduce_axis) y_pred_o = torch.sum(y_pred, dim=reduce_axis) denominator = y_o + y_pred_o diff --git a/monai/metrics/functional/rocauc.py b/monai/metrics/functional/rocauc.py new file mode 100644 index 0000000000..9c00d89d7e --- /dev/null +++ b/monai/metrics/functional/rocauc.py @@ -0,0 +1,124 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import torch +import warnings +import numpy as np +from monai.networks.utils import one_hot + + +def _calculate(y, y_pred): + assert y.ndimension() == y_pred.ndimension() == 1 and len(y) == len( + y_pred + ), "y and y_pred must be 1 dimension data with same length." + assert y.unique().equal( + torch.tensor([0, 1], dtype=y.dtype, device=y.device) + ), "y values must be 0 or 1, can not be all 0 or all 1." + n = len(y) + indexes = y_pred.argsort() + y = y[indexes].cpu().numpy() + y_pred = y_pred[indexes].cpu().numpy() + nneg = auc = tmp_pos = tmp_neg = 0 + + for i in range(n): + y_i = y[i] + if i + 1 < n and y_pred[i] == y_pred[i + 1]: + tmp_pos += y_i + tmp_neg += 1 - y_i + continue + if tmp_pos + tmp_neg > 0: + tmp_pos += y_i + tmp_neg += 1 - y_i + nneg += tmp_neg + auc += tmp_pos * (nneg - tmp_neg / 2) + tmp_pos = tmp_neg = 0 + continue + if y_i == 1: + auc += nneg + else: + nneg += 1 + return auc / (nneg * (n - nneg)) + + +def compute_roc_auc( + y_pred: torch.Tensor, + y: torch.Tensor, + to_onehot_y: bool = False, + softmax: bool = False, + average: Optional[str] = "macro", +): + """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: + `sklearn.metrics.roc_auc_score `_. + + Args: + y_pred: input data to compute, typical classification model output. + it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. + y: ground truth to compute ROC AUC metric, the first dim is batch. + example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. + average (`macro|weighted|micro|None`): type of averaging performed if not binary + classification. Default is 'macro'. + + - 'macro': calculate metrics for each label, and find their unweighted mean. 
+ this does not take label imbalance into account. + - 'weighted': calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - 'micro': calculate metrics globally by considering each element of the label + indicator matrix as a label. + - None: the scores for each class are returned. + + Note: + ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values. + """ + y_pred_ndim = y_pred.ndimension() + y_ndim = y.ndimension() + if y_pred_ndim not in (1, 2): + raise ValueError("predictions should be of shape (batch_size, n_classes) or (batch_size, ).") + if y_ndim not in (1, 2): + raise ValueError("targets should be of shape (batch_size, n_classes) or (batch_size, ).") + if y_pred_ndim == 2 and y_pred.shape[1] == 1: + y_pred = y_pred.squeeze(dim=-1) + y_pred_ndim = 1 + if y_ndim == 2 and y.shape[1] == 1: + y = y.squeeze(dim=-1) + + if y_pred_ndim == 1: + if to_onehot_y: + warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.") + if softmax: + warnings.warn("y_pred has only one channel, softmax=True ignored.") + return _calculate(y, y_pred) + else: + n_classes = y_pred.shape[1] + if to_onehot_y: + y = one_hot(y, n_classes) + if softmax: + y_pred = y_pred.float().softmax(dim=1) + + assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match." 
+ + if average == "micro": + return _calculate(y.flatten(), y_pred.flatten()) + else: + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] + if average is None: + return auc_values + if average == "macro": + return np.mean(auc_values) + if average == "weighted": + weights = [sum(y_) for y_ in y] + return np.average(auc_values, weights=weights) + raise ValueError("unsupported average method.") diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py new file mode 100644 index 0000000000..eac3cbb8e9 --- /dev/null +++ b/monai/metrics/metric.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + + +class Metric(ABC): + """ + Abstract base class for class based API for metrics + """ + + @abstractmethod + def __call__(self, *args, **kwargs) -> None: + """ + Compute metric + """ + raise NotImplementedError + + +class CumulativeMetric(Metric): + """ + Abstract base class for metrics which need to be computed over multiple + samples and need to store intermediate values. To be consistent with + the other metrics, the metric can be called after all samples + were added to compute the final result. 
+ """ + + @abstractmethod + def add_sample(self, *args, **kwargs) -> None: + """ + Add a new sample for evaluation + """ + raise NotImplementedError + + @abstractmethod + def reset(self, *args, **kwargs) -> None: + """ + Reset internally saved values + """ + raise NotImplementedError diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index e8b235b57a..b9018364fb 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -12,114 +12,105 @@ from typing import Optional import torch -import warnings -import numpy as np -from monai.networks.utils import one_hot - - -def _calculate(y, y_pred): - assert y.ndimension() == y_pred.ndimension() == 1 and len(y) == len( - y_pred - ), "y and y_pred must be 1 dimension data with same length." - assert y.unique().equal( - torch.tensor([0, 1], dtype=y.dtype, device=y.device) - ), "y values must be 0 or 1, can not be all 0 or all 1." - n = len(y) - indexes = y_pred.argsort() - y = y[indexes].cpu().numpy() - y_pred = y_pred[indexes].cpu().numpy() - nneg = auc = tmp_pos = tmp_neg = 0 - - for i in range(n): - y_i = y[i] - if i + 1 < n and y_pred[i] == y_pred[i + 1]: - tmp_pos += y_i - tmp_neg += 1 - y_i - continue - if tmp_pos + tmp_neg > 0: - tmp_pos += y_i - tmp_neg += 1 - y_i - nneg += tmp_neg - auc += tmp_pos * (nneg - tmp_neg / 2) - tmp_pos = tmp_neg = 0 - continue - if y_i == 1: - auc += nneg - else: - nneg += 1 - return auc / (nneg * (n - nneg)) - - -def compute_roc_auc( - y_pred: torch.Tensor, - y: torch.Tensor, - to_onehot_y: bool = False, - softmax: bool = False, - average: Optional[str] = "macro", -): - """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: - `sklearn.metrics.roc_auc_score `_. - - Args: - y_pred (torch.Tensor): input data to compute, typical classification model output. - it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. - y (torch.Tensor): ground truth to compute ROC AUC metric, the first dim is batch. 
- example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. - average (`macro|weighted|micro|None`): type of averaging performed if not binary - classification. Default is 'macro'. - - - 'macro': calculate metrics for each label, and find their unweighted mean. - this does not take label imbalance into account. - - 'weighted': calculate metrics for each label, and find their average, - weighted by support (the number of true instances for each label). - - 'micro': calculate metrics globally by considering each element of the label - indicator matrix as a label. - - None: the scores for each class are returned. - - Note: - ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values. +from monai.metrics.metric import CumulativeMetric, Metric +from monai.metrics.functional.rocauc import compute_roc_auc + + +class ROCOUC(Metric): + def __init__( + self, to_onehot_y: bool = False, softmax: bool = False, average: Optional[str] = "macro", + ): + """ + Args: + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. + average (`macro|weighted|micro|None`): type of averaging performed if not binary + classification. Default is 'macro'. + + - 'macro': calculate metrics for each label, and find their unweighted mean. + this does not take label imbalance into account. + - 'weighted': calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - 'micro': calculate metrics globally by considering each element of the label + indicator matrix as a label. + - None: the scores for each class are returned. 
+ """ + super().__init__() + self.to_onehot_y = to_onehot_y + self.softmax = softmax + self.average = average + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> float: + """ + # TODO: add docstring + + Args: + y_pred: + y: + + Returns: + + """ + return compute_roc_auc(y_pred, y, to_onehot_y=self.to_onehot_y, softmax=self.softmax, average=self.average) + + +class CumulativeROCAUC(CumulativeMetric): + """ + Class API to compute ROCAUC by adding individual samples to the metric. """ - y_pred_ndim = y_pred.ndimension() - y_ndim = y.ndimension() - if y_pred_ndim not in (1, 2): - raise ValueError("predictions should be of shape (batch_size, n_classes) or (batch_size, ).") - if y_ndim not in (1, 2): - raise ValueError("targets should be of shape (batch_size, n_classes) or (batch_size, ).") - if y_pred_ndim == 2 and y_pred.shape[1] == 1: - y_pred = y_pred.squeeze(dim=-1) - y_pred_ndim = 1 - if y_ndim == 2 and y.shape[1] == 1: - y = y.squeeze(dim=-1) - - if y_pred_ndim == 1: - if to_onehot_y: - warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.") - if softmax: - warnings.warn("y_pred has only one channel, softmax=True ignored.") - return _calculate(y, y_pred) - else: - n_classes = y_pred.shape[1] - if to_onehot_y: - y = one_hot(y, n_classes) - if softmax: - y_pred = y_pred.float().softmax(dim=1) - - assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match." 
- - if average == "micro": - return _calculate(y.flatten(), y_pred.flatten()) - else: - y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) - auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] - if average is None: - return auc_values - if average == "macro": - return np.mean(auc_values) - if average == "weighted": - weights = [sum(y_) for y_ in y] - return np.average(auc_values, weights=weights) - raise ValueError("unsupported average method.") + + def __init__( + self, to_onehot_y: bool = False, softmax: bool = False, average: Optional[str] = "macro", + ): + """ + Args: + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. + average (`macro|weighted|micro|None`): type of averaging performed if not binary + classification. Default is 'macro'. + + - 'macro': calculate metrics for each label, and find their unweighted mean. + this does not take label imbalance into account. + - 'weighted': calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - 'micro': calculate metrics globally by considering each element of the label + indicator matrix as a label. + - None: the scores for each class are returned. + """ + super().__init__() + self.to_onehot_y = to_onehot_y + self.softmax = softmax + self.average = average + + self.y_pred_stored = [] + self.y_stored = [] + + def __call__(self) -> float: + """ + Compute the metric result from previously added samples. 
+ """ + y_pred = torch.cat(self.y_pred_stored) + y = torch.cat(self.y_stored) + return compute_roc_auc(y_pred, y, to_onehot_y=self.to_onehot_y, softmax=self.softmax, average=self.average) + + def add_sample(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: + """ + Store predictions and targets for a later ROC AUC computation. + + Args: + y_pred: predictions to store; stored tensors are concatenated with `torch.cat` when the metric is called. + y: ground truth for `y_pred`, stored and concatenated the same way. + + Returns: + None. + """ + self.y_pred_stored.append(y_pred) + self.y_stored.append(y) + + def reset(self) -> None: + """ + Reset internal states for new computation + """ + self.y_pred_stored = [] + self.y_stored = []