From f4ae590a9efc2159898257526d0a1c1cb8e4ca2a Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 14 Jun 2020 19:35:27 +0200 Subject: [PATCH 1/5] create metrics.functional and add new files --- monai/metrics/__init__.py | 3 +- monai/metrics/base.py | 10 ++ monai/metrics/functional/__init__.py | 13 +++ monai/metrics/{ => functional}/meandice.py | 0 monai/metrics/functional/rocauc.py | 125 +++++++++++++++++++++ monai/metrics/rocauc.py | 115 ------------------- 6 files changed, 149 insertions(+), 117 deletions(-) create mode 100644 monai/metrics/base.py create mode 100644 monai/metrics/functional/__init__.py rename monai/metrics/{ => functional}/meandice.py (100%) create mode 100644 monai/metrics/functional/rocauc.py diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 0e72dfe5ab..dc20b47670 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -9,5 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .meandice import compute_meandice -from .rocauc import compute_roc_auc +from monai.metrics.functional import * diff --git a/monai/metrics/base.py b/monai/metrics/base.py new file mode 100644 index 0000000000..d0044e3563 --- /dev/null +++ b/monai/metrics/base.py @@ -0,0 +1,10 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/monai/metrics/functional/__init__.py b/monai/metrics/functional/__init__.py new file mode 100644 index 0000000000..0e72dfe5ab --- /dev/null +++ b/monai/metrics/functional/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .meandice import compute_meandice +from .rocauc import compute_roc_auc diff --git a/monai/metrics/meandice.py b/monai/metrics/functional/meandice.py similarity index 100% rename from monai/metrics/meandice.py rename to monai/metrics/functional/meandice.py diff --git a/monai/metrics/functional/rocauc.py b/monai/metrics/functional/rocauc.py new file mode 100644 index 0000000000..e8b235b57a --- /dev/null +++ b/monai/metrics/functional/rocauc.py @@ -0,0 +1,125 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import torch +import warnings +import numpy as np +from monai.networks.utils import one_hot + + +def _calculate(y, y_pred): + assert y.ndimension() == y_pred.ndimension() == 1 and len(y) == len( + y_pred + ), "y and y_pred must be 1 dimension data with same length." + assert y.unique().equal( + torch.tensor([0, 1], dtype=y.dtype, device=y.device) + ), "y values must be 0 or 1, can not be all 0 or all 1." + n = len(y) + indexes = y_pred.argsort() + y = y[indexes].cpu().numpy() + y_pred = y_pred[indexes].cpu().numpy() + nneg = auc = tmp_pos = tmp_neg = 0 + + for i in range(n): + y_i = y[i] + if i + 1 < n and y_pred[i] == y_pred[i + 1]: + tmp_pos += y_i + tmp_neg += 1 - y_i + continue + if tmp_pos + tmp_neg > 0: + tmp_pos += y_i + tmp_neg += 1 - y_i + nneg += tmp_neg + auc += tmp_pos * (nneg - tmp_neg / 2) + tmp_pos = tmp_neg = 0 + continue + if y_i == 1: + auc += nneg + else: + nneg += 1 + return auc / (nneg * (n - nneg)) + + +def compute_roc_auc( + y_pred: torch.Tensor, + y: torch.Tensor, + to_onehot_y: bool = False, + softmax: bool = False, + average: Optional[str] = "macro", +): + """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: + `sklearn.metrics.roc_auc_score `_. + + Args: + y_pred (torch.Tensor): input data to compute, typical classification model output. + it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. + y (torch.Tensor): ground truth to compute ROC AUC metric, the first dim is batch. + example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. + average (`macro|weighted|micro|None`): type of averaging performed if not binary + classification. Default is 'macro'. + + - 'macro': calculate metrics for each label, and find their unweighted mean. 
+ this does not take label imbalance into account. + - 'weighted': calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - 'micro': calculate metrics globally by considering each element of the label + indicator matrix as a label. + - None: the scores for each class are returned. + + Note: + ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values. + + """ + y_pred_ndim = y_pred.ndimension() + y_ndim = y.ndimension() + if y_pred_ndim not in (1, 2): + raise ValueError("predictions should be of shape (batch_size, n_classes) or (batch_size, ).") + if y_ndim not in (1, 2): + raise ValueError("targets should be of shape (batch_size, n_classes) or (batch_size, ).") + if y_pred_ndim == 2 and y_pred.shape[1] == 1: + y_pred = y_pred.squeeze(dim=-1) + y_pred_ndim = 1 + if y_ndim == 2 and y.shape[1] == 1: + y = y.squeeze(dim=-1) + + if y_pred_ndim == 1: + if to_onehot_y: + warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.") + if softmax: + warnings.warn("y_pred has only one channel, softmax=True ignored.") + return _calculate(y, y_pred) + else: + n_classes = y_pred.shape[1] + if to_onehot_y: + y = one_hot(y, n_classes) + if softmax: + y_pred = y_pred.float().softmax(dim=1) + + assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match." 
+ + if average == "micro": + return _calculate(y.flatten(), y_pred.flatten()) + else: + y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) + auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] + if average is None: + return auc_values + if average == "macro": + return np.mean(auc_values) + if average == "weighted": + weights = [sum(y_) for y_ in y] + return np.average(auc_values, weights=weights) + raise ValueError("unsupported average method.") diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index e8b235b57a..d0044e3563 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -8,118 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional - -import torch -import warnings -import numpy as np -from monai.networks.utils import one_hot - - -def _calculate(y, y_pred): - assert y.ndimension() == y_pred.ndimension() == 1 and len(y) == len( - y_pred - ), "y and y_pred must be 1 dimension data with same length." - assert y.unique().equal( - torch.tensor([0, 1], dtype=y.dtype, device=y.device) - ), "y values must be 0 or 1, can not be all 0 or all 1." 
- n = len(y) - indexes = y_pred.argsort() - y = y[indexes].cpu().numpy() - y_pred = y_pred[indexes].cpu().numpy() - nneg = auc = tmp_pos = tmp_neg = 0 - - for i in range(n): - y_i = y[i] - if i + 1 < n and y_pred[i] == y_pred[i + 1]: - tmp_pos += y_i - tmp_neg += 1 - y_i - continue - if tmp_pos + tmp_neg > 0: - tmp_pos += y_i - tmp_neg += 1 - y_i - nneg += tmp_neg - auc += tmp_pos * (nneg - tmp_neg / 2) - tmp_pos = tmp_neg = 0 - continue - if y_i == 1: - auc += nneg - else: - nneg += 1 - return auc / (nneg * (n - nneg)) - - -def compute_roc_auc( - y_pred: torch.Tensor, - y: torch.Tensor, - to_onehot_y: bool = False, - softmax: bool = False, - average: Optional[str] = "macro", -): - """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: - `sklearn.metrics.roc_auc_score `_. - - Args: - y_pred (torch.Tensor): input data to compute, typical classification model output. - it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. - y (torch.Tensor): ground truth to compute ROC AUC metric, the first dim is batch. - example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. - average (`macro|weighted|micro|None`): type of averaging performed if not binary - classification. Default is 'macro'. - - - 'macro': calculate metrics for each label, and find their unweighted mean. - this does not take label imbalance into account. - - 'weighted': calculate metrics for each label, and find their average, - weighted by support (the number of true instances for each label). - - 'micro': calculate metrics globally by considering each element of the label - indicator matrix as a label. - - None: the scores for each class are returned. - - Note: - ROCAUC expects y to be comprised of 0's and 1's. 
`y_pred` must be either prob. estimates or confidence values. - - """ - y_pred_ndim = y_pred.ndimension() - y_ndim = y.ndimension() - if y_pred_ndim not in (1, 2): - raise ValueError("predictions should be of shape (batch_size, n_classes) or (batch_size, ).") - if y_ndim not in (1, 2): - raise ValueError("targets should be of shape (batch_size, n_classes) or (batch_size, ).") - if y_pred_ndim == 2 and y_pred.shape[1] == 1: - y_pred = y_pred.squeeze(dim=-1) - y_pred_ndim = 1 - if y_ndim == 2 and y.shape[1] == 1: - y = y.squeeze(dim=-1) - - if y_pred_ndim == 1: - if to_onehot_y: - warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.") - if softmax: - warnings.warn("y_pred has only one channel, softmax=True ignored.") - return _calculate(y, y_pred) - else: - n_classes = y_pred.shape[1] - if to_onehot_y: - y = one_hot(y, n_classes) - if softmax: - y_pred = y_pred.float().softmax(dim=1) - - assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match." - - if average == "micro": - return _calculate(y.flatten(), y_pred.flatten()) - else: - y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) - auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] - if average is None: - return auc_values - if average == "macro": - return np.mean(auc_values) - if average == "weighted": - weights = [sum(y_) for y_ in y] - return np.average(auc_values, weights=weights) - raise ValueError("unsupported average method.") From a3d63796dc015c389c783d15ef6ba200256b1dad Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Sun, 14 Jun 2020 19:58:03 +0200 Subject: [PATCH 2/5] rename base, add minimal example --- monai/metrics/__init__.py | 2 ++ monai/metrics/base.py | 10 ------- monai/metrics/metric.py | 50 ++++++++++++++++++++++++++++++++ monai/metrics/rocauc.py | 61 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 10 deletions(-) delete mode 100644 monai/metrics/base.py create mode 100644 monai/metrics/metric.py diff --git 
a/monai/metrics/__init__.py b/monai/metrics/__init__.py index dc20b47670..f0fefb86d1 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -10,3 +10,5 @@ # limitations under the License. from monai.metrics.functional import * +from monai.metrics.metric import Metric, PartialMetric + diff --git a/monai/metrics/base.py b/monai/metrics/base.py deleted file mode 100644 index d0044e3563..0000000000 --- a/monai/metrics/base.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py new file mode 100644 index 0000000000..61718e6f12 --- /dev/null +++ b/monai/metrics/metric.py @@ -0,0 +1,50 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod + + +class Metric(ABC): + """ + Abstract base class for class based API for metrics + """ + + @abstractmethod + def __call__(self, *args, **kwargs) -> None: + raise NotImplementedError + + +class PartialMetric(Metric): + """ + Abstract base class for metrics which need to be computed over multiple + samples and need to store intermediate values. + """ + + @abstractmethod + def add_sample(self, *args, **kwargs) -> None: + """ + Add a new sample for evaluation + """ + raise NotImplementedError + + @abstractmethod + def evaluate(self, *args, **kwargs) -> None: + """ + Evaluate the metric from added samples + """ + raise NotImplementedError + + @abstractmethod + def reset(self, *args, **kwargs) -> None: + """ + Reset internally saved values + """ + raise NotImplementedError diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index d0044e3563..578faf9d8e 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -8,3 +8,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from monai.metrics.metric import PartialMetric +from monai.metrics.functional.rocauc import compute_roc_auc + + +class ROCAUC(PartialMetric): + """ + Class based API to either compute ROCAUC directly or compute it over + indivudally added samples + """ + + def __init__(self, + to_onehot_y: bool = False, + softmax: bool = False, + average: Optional[str] = "macro", + ): + """ + Args: + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. + average (`macro|weighted|micro|None`): type of averaging performed if not binary + classification. Default is 'macro'. + + - 'macro': calculate metrics for each label, and find their unweighted mean. + this does not take label imbalance into account. 
+ - 'weighted': calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - 'micro': calculate metrics globally by considering each element of the label + indicator matrix as a label. + - None: the scores for each class are returned. + """ + super().__init__() + self.to_onehot_y = to_onehot_y + self.softmax = softmax + self.average = average + + self.y_pred_stored = [] + self.y_stored = [] + + def __call__(self, + y_pred: torch.Tensor, + y: torch.Tensor, + ): + return compute_roc_auc(y_pred, y, + to_onehot_y=to_onehot_y, + softmax=softmax, + average=average, + ) + + def add_sample(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: + self.y_pred_stored.append(y_pred) + self.y_stored.append(y) + + def evaluate(self) -> float: + y_pred = torch.cat(self.y_pred_stored) + y = torch.cat(self.y_stored) + return self(y_pred, y) + + def reset(self) -> None: + self.y_pred_stored = [] + self.y_stored = [] From 96fe6a46f2cfac8ab972d1a6c40a7bfad3484f71 Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Thu, 18 Jun 2020 22:17:20 +0200 Subject: [PATCH 3/5] move dice class API to correct file --- monai/metrics/__init__.py | 2 + monai/metrics/dice.py | 101 +++++++++++++++++++++++++++ monai/metrics/functional/meandice.py | 87 ----------------------- 3 files changed, 103 insertions(+), 87 deletions(-) create mode 100644 monai/metrics/dice.py diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index bffcb74535..138fcfb22a 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -11,3 +11,5 @@ from monai.metrics.functional import * from monai.metrics.metric import Metric, PartialMetric +from monai.metrics.dice import DiceMetric +from monai.metrics.rocauc import ROCAUC diff --git a/monai/metrics/dice.py b/monai/metrics/dice.py new file mode 100644 index 0000000000..a96b6b2a37 --- /dev/null +++ b/monai/metrics/dice.py @@ -0,0 +1,101 @@ +# Copyright 2020 MONAI Consortium +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.metrics.functional.meandice import compute_meandice + + +class DiceMetric: + """ + Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. + Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). + Axis N of `input` is expected to have logit predictions for each class rather than being image channels, + while the same axis of `target` can be 1 or N (one-hot format). The `smooth` parameter is a value added to the + intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0, + this value should be small. The `include_background` class attribute can be set to False for an instance of + DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. + If the non-background segmentations are small compared to the total image size they can get overwhelmed by + the signal from the background so excluding it in such cases helps convergence. + + """ + + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + mutually_exclusive: bool = False, + sigmoid: bool = False, + logit_thresh: float = 0.5, + reduction: str = "mean", + ): + super().__init__() + + if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: + raise ValueError(f"reduction={reduction} is invalid. 
Valid options are: none, mean or sum.") + + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.mutually_exclusive = mutually_exclusive + self.sigmoid = sigmoid + self.logit_thresh = logit_thresh + self.reduction = reduction + + self.not_nans = None # keep track for valid elements in the batch + + def __call__(self, input: torch.Tensor, target: torch.Tensor): + + # compute dice (BxC) for each channel for each batch + f = compute_meandice( + y_pred=input, + y=target, + include_background=self.include_background, + to_onehot_y=self.to_onehot_y, + mutually_exclusive=self.mutually_exclusive, + sigmoid=self.sigmoid, + logit_thresh=self.logit_thresh, + ) + + # some dice elements might be Nan (if ground truth y was missing (zeros)) + # we need to account for it + + nans = torch.isnan(f) + not_nans = (~nans).float() + f[nans] = 0 + + t_zero = torch.zeros(1, device=f.device, dtype=torch.float) + + if self.reduction == "mean": + # 2 steps, first, mean by batch (accounting for nans), then by channel + + not_nans = not_nans.sum(dim=0) + f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + + not_nans = not_nans.sum() + f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # channel average + + elif self.reduction == "sum": + not_nans = not_nans.sum() + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == "mean_batch": + not_nans = not_nans.sum(dim=0) + f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average + elif self.reduction == "sum_batch": + not_nans = not_nans.sum(dim=0) + f = f.sum(dim=0) # the batch sum + elif self.reduction == "none": + pass + else: + raise ValueError(f"reduction={self.reduction} is invalid.") + + self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid + + return f diff --git a/monai/metrics/functional/meandice.py b/monai/metrics/functional/meandice.py index f02c874d93..692048f1bc 100644 --- 
a/monai/metrics/functional/meandice.py +++ b/monai/metrics/functional/meandice.py @@ -15,93 +15,6 @@ from monai.networks.utils import one_hot -class DiceMetric: - """ - Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. - Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). - Axis N of `input` is expected to have logit predictions for each class rather than being image channels, - while the same axis of `target` can be 1 or N (one-hot format). The `smooth` parameter is a value added to the - intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0, - this value should be small. The `include_background` class attribute can be set to False for an instance of - DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. - If the non-background segmentations are small compared to the total image size they can get overwhelmed by - the signal from the background so excluding it in such cases helps convergence. - - """ - - def __init__( - self, - include_background: bool = True, - to_onehot_y: bool = False, - mutually_exclusive: bool = False, - sigmoid: bool = False, - logit_thresh: float = 0.5, - reduction: str = "mean", - ): - super().__init__() - - if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: - raise ValueError(f"reduction={reduction} is invalid. 
Valid options are: none, mean or sum.") - - self.include_background = include_background - self.to_onehot_y = to_onehot_y - self.mutually_exclusive = mutually_exclusive - self.sigmoid = sigmoid - self.logit_thresh = logit_thresh - self.reduction = reduction - - self.not_nans = None # keep track for valid elements in the batch - - def __call__(self, input: torch.Tensor, target: torch.Tensor): - - # compute dice (BxC) for each channel for each batch - f = compute_meandice( - y_pred=input, - y=target, - include_background=self.include_background, - to_onehot_y=self.to_onehot_y, - mutually_exclusive=self.mutually_exclusive, - sigmoid=self.sigmoid, - logit_thresh=self.logit_thresh, - ) - - # some dice elements might be Nan (if ground truth y was missing (zeros)) - # we need to account for it - - nans = torch.isnan(f) - not_nans = (~nans).float() - f[nans] = 0 - - t_zero = torch.zeros(1, device=f.device, dtype=torch.float) - - if self.reduction == "mean": - # 2 steps, first, mean by batch (accounting for nans), then by channel - - not_nans = not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average - - not_nans = not_nans.sum() - f = torch.where(not_nans > 0, f.sum() / not_nans, t_zero) # channel average - - elif self.reduction == "sum": - not_nans = not_nans.sum() - f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction == "mean_batch": - not_nans = not_nans.sum(dim=0) - f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average - elif self.reduction == "sum_batch": - not_nans = not_nans.sum(dim=0) - f = f.sum(dim=0) # the batch sum - elif self.reduction == "none": - pass - else: - raise ValueError(f"reduction={self.reduction} is invalid.") - - self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid - - return f - - def compute_meandice( y_pred: torch.Tensor, y: torch.Tensor, From 97e65cffde751d3658cf48aefef9795eb9fbe319 Mon Sep 17 00:00:00 
2001 From: mibaumgartner Date: Thu, 18 Jun 2020 23:19:27 +0200 Subject: [PATCH 4/5] add cumulative dice metric --- monai/metrics/__init__.py | 6 +- monai/metrics/dice.py | 90 +++++++++++++++++++++++++-- monai/metrics/functional/meandice.py | 80 +++++++++++++++--------- monai/metrics/functional/rocauc.py | 5 +- monai/metrics/metric.py | 16 +++-- monai/metrics/rocauc.py | 93 +++++++++++++++++++++------- 6 files changed, 217 insertions(+), 73 deletions(-) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 138fcfb22a..18ffe335c3 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -10,6 +10,6 @@ # limitations under the License. from monai.metrics.functional import * -from monai.metrics.metric import Metric, PartialMetric -from monai.metrics.dice import DiceMetric -from monai.metrics.rocauc import ROCAUC +from monai.metrics.metric import Metric, CumulativeMetric +from monai.metrics.dice import Dice +from monai.metrics.rocauc import CumulativeROCAUC diff --git a/monai/metrics/dice.py b/monai/metrics/dice.py index a96b6b2a37..4d853b4971 100644 --- a/monai/metrics/dice.py +++ b/monai/metrics/dice.py @@ -11,10 +11,11 @@ import torch -from monai.metrics.functional.meandice import compute_meandice +from monai.metrics.functional.meandice import compute_meandice, _convert_predictions +from monai.metrics.metric import Metric, CumulativeMetric -class DiceMetric: +class Dice(Metric): """ Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). @@ -25,7 +26,6 @@ class DiceMetric: DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. If the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. 
- """ def __init__( @@ -37,6 +37,17 @@ def __init__( logit_thresh: float = 0.5, reduction: str = "mean", ): + """ + Args: + include_background: whether to skip Dice computation on the first channel of + the predicted output. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. + logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + """ super().__init__() if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: @@ -51,7 +62,7 @@ def __init__( self.not_nans = None # keep track for valid elements in the batch - def __call__(self, input: torch.Tensor, target: torch.Tensor): + def __call__(self, input: torch.Tensor, target: torch.Tensor) -> float: # compute dice (BxC) for each channel for each batch f = compute_meandice( @@ -99,3 +110,74 @@ def __call__(self, input: torch.Tensor, target: torch.Tensor): self.not_nans = not_nans # preserve, since we may need it later to know how many elements were valid return f + + +class CumulativeDice(CumulativeMetric): + def __init__( + self, + include_background: bool = True, + to_onehot_y: bool = False, + mutually_exclusive: bool = False, + sigmoid: bool = False, + logit_thresh: float = 0.5, + reduction: str = "mean", + ): + """ + Args: + include_background: whether to skip Dice computation on the first channel of + the predicted output. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. 
Defaults to False. + logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + """ + super().__init__() + + if reduction not in ["none", "mean", "sum", "mean_batch", "sum_batch"]: + raise ValueError(f"reduction={reduction} is invalid. Valid options are: none, mean or sum.") + + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.mutually_exclusive = mutually_exclusive + self.sigmoid = sigmoid + self.logit_thresh = logit_thresh + self.reduction = reduction + + self._intersection: torch.Tensor = 0 + self._sum_y: torch.Tensor = 0 + self._sum_y_pred: torch.Tensor = 0 + + def __call__(self): + denominator = self._sum_y + self._sum_y_pred + f = torch.where( + self._sum_y > 0, + (2.0 * self._intersection) / denominator, + torch.tensor(float("nan"), device=self._sum_y.device), + ) + return f # returns array of Dice shape: [Batch, n_classes] + + def add_sample(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: + y_pred, y = _convert_predictions( + y_pred, + y, + include_background=self.include_background, + to_onehot_y=self.to_onehot_y, + mutually_exclusive=self.mutually_exclusive, + sigmoid=self.sigmoid, + logit_thresh=self.logit_thresh, + ) + n_len = len(y_pred.shape) + + # reducing only spatial dimensions (not batch nor channels) + reduce_axis = list(range(2, n_len)) + + self._intersection = torch.sum(y * y_pred, dim=reduce_axis) + self._intersection + self._sum_y = torch.sum(y, dim=reduce_axis) + self._sum_y + self._sum_y_pred = torch.sum(y_pred, dim=reduce_axis) + self._sum_y_pred + + def reset(self, *args, **kwargs) -> None: + self._intersection: torch.Tensor = 0 + self._sum_y: torch.Tensor = 0 + self._sum_y_pred: torch.Tensor = 0 diff --git a/monai/metrics/functional/meandice.py b/monai/metrics/functional/meandice.py index 692048f1bc..cd4463d76e 100644 --- a/monai/metrics/functional/meandice.py +++ b/monai/metrics/functional/meandice.py @@ -15,7 +15,7 @@ 
from monai.networks.utils import one_hot -def compute_meandice( +def _convert_predictions( y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, @@ -24,35 +24,7 @@ def compute_meandice( sigmoid: bool = False, logit_thresh: float = 0.5, ): - """Computes Dice score metric from full size Tensor and collects average. - - Args: - y_pred (torch.Tensor): input data to compute, typical segmentation model output. - it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. - y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch. - example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32]. - alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly. - include_background: whether to skip Dice computation on the first channel of - the predicted output. Defaults to True. - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using - a combination of argmax and to_onehot. Defaults to False. - sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. - logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) - `y_pred` into a binary matrix. Defaults to 0.5. - - Returns: - Dice scores per batch and per class, (shape [batch_size, n_classes]). - - Note: - This method provides two options to convert `y_pred` into a binary matrix - (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``, - (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh`` - (optionally with a ``sigmoid`` function before thresholding). 
- - """ n_classes = y_pred.shape[1] - n_len = len(y_pred.shape) if sigmoid: y_pred = y_pred.float().sigmoid() @@ -89,12 +61,60 @@ def compute_meandice( ) y = y.float() y_pred = y_pred.float() + return y_pred, y + + +def compute_meandice( + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + to_onehot_y: bool = False, + mutually_exclusive: bool = False, + sigmoid: bool = False, + logit_thresh: float = 0.5, +): + """Computes Dice score metric from full size Tensor and collects average. + + Args: + y_pred : input data to compute, typical segmentation model output. + it must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. + y: ground truth to compute mean dice metric, the first dim is batch. + example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32]. + alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly. + include_background: whether to skip Dice computation on the first channel of + the predicted output. Defaults to True. + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + mutually_exclusive: if True, `y_pred` will be converted into a binary matrix using + a combination of argmax and to_onehot. Defaults to False. + sigmoid: whether to add sigmoid function to y_pred before computation. Defaults to False. + logit_thresh: the threshold value used to convert (after sigmoid if `sigmoid=True`) + `y_pred` into a binary matrix. Defaults to 0.5. + + Returns: + Dice scores per batch and per class, (shape [batch_size, n_classes]). + + Note: + This method provides two options to convert `y_pred` into a binary matrix + (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``, + (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh`` + (optionally with a ``sigmoid`` function before thresholding). 
+ """ + y_pred, y = _convert_predictions( + y_pred, + y, + include_background=include_background, + to_onehot_y=to_onehot_y, + mutually_exclusive=mutually_exclusive, + sigmoid=sigmoid, + logit_thresh=logit_thresh, + ) + n_len = len(y_pred.shape) # reducing only spatial dimensions (not batch nor channels) reduce_axis = list(range(2, n_len)) intersection = torch.sum(y * y_pred, dim=reduce_axis) - y_o = torch.sum(y, reduce_axis) + y_o = torch.sum(y, dim=reduce_axis) y_pred_o = torch.sum(y_pred, dim=reduce_axis) denominator = y_o + y_pred_o diff --git a/monai/metrics/functional/rocauc.py b/monai/metrics/functional/rocauc.py index e8b235b57a..9c00d89d7e 100644 --- a/monai/metrics/functional/rocauc.py +++ b/monai/metrics/functional/rocauc.py @@ -62,9 +62,9 @@ def compute_roc_auc( sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_. Args: - y_pred (torch.Tensor): input data to compute, typical classification model output. + y_pred: input data to compute, typical classification model output. it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. - y (torch.Tensor): ground truth to compute ROC AUC metric, the first dim is batch. + y: ground truth to compute ROC AUC metric, the first dim is batch. example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. @@ -81,7 +81,6 @@ def compute_roc_auc( Note: ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values. 
- """ y_pred_ndim = y_pred.ndimension() y_ndim = y.ndimension() diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py index 61718e6f12..eac3cbb8e9 100644 --- a/monai/metrics/metric.py +++ b/monai/metrics/metric.py @@ -19,13 +19,18 @@ class Metric(ABC): @abstractmethod def __call__(self, *args, **kwargs) -> None: + """ + Compute metric + """ raise NotImplementedError -class PartialMetric(Metric): +class CumulativeMetric(Metric): """ Abstract base class for metrics which need to be computed over multiple - samples and need to store intermediate values. + samples and need to store intermediate values. To be consistent with + the other metrics, the metric can be called after all samples + were added to compute the final result. """ @abstractmethod @@ -34,13 +39,6 @@ def add_sample(self, *args, **kwargs) -> None: Add a new sample for evaluation """ raise NotImplementedError - - @abstractmethod - def evaluate(self, *args, **kwargs) -> None: - """ - Evaluate the metric from added samples - """ - raise NotImplementedError @abstractmethod def reset(self, *args, **kwargs) -> None: diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 578faf9d8e..b9018364fb 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -9,21 +9,60 @@ # See the License for the specific language governing permissions and # limitations under the License. -from monai.metrics.metric import PartialMetric +from typing import Optional + +import torch + +from monai.metrics.metric import CumulativeMetric, Metric from monai.metrics.functional.rocauc import compute_roc_auc -class ROCAUC(PartialMetric): +class ROCOUC(Metric): + def __init__( + self, to_onehot_y: bool = False, softmax: bool = False, average: Optional[str] = "macro", + ): + """ + Args: + to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. + softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. 
+ average (`macro|weighted|micro|None`): type of averaging performed if not binary + classification. Default is 'macro'. + + - 'macro': calculate metrics for each label, and find their unweighted mean. + this does not take label imbalance into account. + - 'weighted': calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - 'micro': calculate metrics globally by considering each element of the label + indicator matrix as a label. + - None: the scores for each class are returned. + """ + super().__init__() + self.to_onehot_y = to_onehot_y + self.softmax = softmax + self.average = average + + def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> float: + """ + Compute ROC AUC for the given predictions by calling `compute_roc_auc`. + + Args: + y_pred: input data to compute, typical classification model output; first dim is batch. + y: ground truth to compute ROC AUC metric; first dim is batch. + + Returns: + the result of `compute_roc_auc` using the configured `to_onehot_y`, `softmax` and `average`. + """ + return compute_roc_auc(y_pred, y, to_onehot_y=self.to_onehot_y, softmax=self.softmax, average=self.average) + + +class CumulativeROCAUC(CumulativeMetric): """ - Class based API to either compute ROCAUC directly or compute it over - indivudally added samples + Class API to compute ROCAUC by adding individual samples to the metric. """ - def __init__(self, - to_onehot_y: bool = False, - softmax: bool = False, - average: Optional[str] = "macro", - ): + def __init__( + self, to_onehot_y: bool = False, softmax: bool = False, average: Optional[str] = "macro", + ): """ Args: to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. @@ -43,29 +82,35 @@ def __init__(self, self.to_onehot_y = to_onehot_y self.softmax = softmax self.average = average - + self.y_pred_stored = [] self.y_stored = [] - def __call__(self, - y_pred: torch.Tensor, - y: torch.Tensor, - ): - return compute_roc_auc(y_pred, y, - to_onehot_y=to_onehot_y, - softmax=softmax, - average=average, - ) + def __call__(self) -> float: + """ + Compute the metric result from previously added samples.
+ """ + y_pred = torch.cat(self.y_pred_stored) + y = torch.cat(self.y_stored) + return compute_roc_auc(y_pred, y, to_onehot_y=self.to_onehot_y, softmax=self.softmax, average=self.average) def add_sample(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: + """ + Store predictions and ground truth for a later call to the metric. + + Args: + y_pred: predictions to append to `y_pred_stored`. + y: ground truth to append to `y_stored`. + + Returns: + None. + """ self.y_pred_stored.append(y_pred) self.y_stored.append(y) - def evaluate(self) -> float: - y_pred = torch.cat(self.y_pred_stored) - y = torch.cat(self.y_stored) - return self(y_pred, y) - def reset(self) -> None: + """ + Reset internal states for new computation + """ self.y_pred_stored = [] self.y_stored = [] From 6eb9a290eca3609e3d31b1d43b0e11c21603789b Mon Sep 17 00:00:00 2001 From: mibaumgartner Date: Thu, 18 Jun 2020 23:24:58 +0200 Subject: [PATCH 5/5] update __init__ --- monai/metrics/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 18ffe335c3..23432b32bb 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -11,5 +11,5 @@ from monai.metrics.functional import * from monai.metrics.metric import Metric, CumulativeMetric -from monai.metrics.dice import Dice -from monai.metrics.rocauc import CumulativeROCAUC +from monai.metrics.dice import Dice, CumulativeDice +from monai.metrics.rocauc import ROCOUC, CumulativeROCAUC