diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index b543ab3f40..6e48aeabe6 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -10,6 +10,11 @@ Metrics ------ .. autofunction:: compute_froc_score +`Metric` +-------- +.. autoclass:: Metric + :members: + `Mean Dice` ----------- .. autofunction:: compute_meandice @@ -21,6 +26,9 @@ Metrics -------------------------- .. autofunction:: compute_roc_auc +.. autoclass:: ROCAUCMetric + :members: + `Confusion matrix` ------------------ .. autofunction:: get_confusion_matrix diff --git a/monai/config/__init__.py b/monai/config/__init__.py index f1c7707d1f..fed8e49771 100644 --- a/monai/config/__init__.py +++ b/monai/config/__init__.py @@ -18,4 +18,4 @@ print_gpu_info, print_system_info, ) -from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor +from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor, TensorOrList diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py index daa9b10052..375ae460b2 100644 --- a/monai/config/type_definitions.py +++ b/monai/config/type_definitions.py @@ -9,12 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Hashable, Iterable, TypeVar, Union +from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union import numpy as np import torch -__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor"] +__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor", "TensorOrList"] """Commonly used concepts This module provides naming and type specifications for commonly used concepts @@ -55,6 +55,7 @@ container must be iterable. """ + DtypeLike = Union[ np.dtype, type, @@ -67,3 +68,10 @@ # Generic type which can represent either a numpy.ndarray or a torch.Tensor # Unlike Union can create a dependence between parameter(s) / return(s) NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor) + + +TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]] +"""TensorOrList + +The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`. +""" diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py index 551fd29199..f33151f832 100644 --- a/monai/handlers/confusion_matrix.py +++ b/monai/handlers/confusion_matrix.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Union +from typing import Callable, Union import torch from monai.handlers.iteration_metric import IterationMetric -from monai.metrics import ConfusionMatrixMetric, compute_confusion_matrix_metric -from monai.metrics.utils import MetricReduction, do_metric_reduction +from monai.metrics import ConfusionMatrixMetric +from monai.metrics.utils import MetricReduction class ConfusionMatrix(IterationMetric): @@ -30,6 +30,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -47,6 +48,9 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. See also: :py:meth:`monai.metrics.confusion_matrix` @@ -55,7 +59,7 @@ def __init__( include_background=include_background, metric_name=metric_name, compute_sample=False, - reduction=MetricReduction.NONE, + reduction=reduction, ) self.metric_name = metric_name super().__init__( @@ -64,7 +68,3 @@ def __init__( device=device, save_details=save_details, ) - - def _reduce(self, scores) -> Any: - confusion_matrix, _ = do_metric_reduction(scores, MetricReduction.MEAN) - return compute_confusion_matrix_metric(self.metric_name, confusion_matrix) diff --git a/monai/handlers/hausdorff_distance.py b/monai/handlers/hausdorff_distance.py index 042a587852..713e1c8d3a 100644 --- a/monai/handlers/hausdorff_distance.py +++ b/monai/handlers/hausdorff_distance.py @@ -32,6 +32,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -48,15 +49,17 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: hausdorff distance of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. """ - super().__init__(output_transform, device=device) metric_fn = HausdorffDistanceMetric( include_background=include_background, distance_metric=distance_metric, percentile=percentile, directed=directed, - reduction=MetricReduction.NONE, + reduction=reduction, ) super().__init__( metric_fn=metric_fn, diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 434dd483ed..a0428d80c4 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -14,11 +14,11 @@ import torch from monai.handlers.utils import evenly_divisible_all_gather -from monai.metrics import do_metric_reduction -from monai.utils import MetricReduction, exact_version, optional_import +from monai.metrics import Metric +from monai.utils import exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") -Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") +IgniteMetric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.4", exact_version, "reinit__is_reduced") if TYPE_CHECKING: from ignite.engine import Engine @@ -26,7 +26,7 @@ Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") -class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to optional_import +class IterationMetric(IgniteMetric): # type: ignore[valid-type, misc] # due to optional_import """ Class for metrics that should be computed on every iteration and compute final results when epoch completed. Similar to the `EpochMetric` in ignite: @@ -46,7 +46,7 @@ class IterationMetric(Metric): # type: ignore[valid-type, misc] # due to option def __init__( self, - metric_fn: Callable, + metric_fn: Metric, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, @@ -78,19 +78,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output - def _compute(y_pred, y): - if isinstance(y_pred, torch.Tensor): - y_pred = y_pred.detach() - if isinstance(y, torch.Tensor): - y = y.detach() - score = self.metric_fn(y_pred, y) - return score[0] if isinstance(score, (tuple, list)) else score - - if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): - # if y_pred or y is a list of channel-first data, add batch dim and compute metric, then concat the scores - score = torch.cat([_compute(p_.unsqueeze(0), y_.unsqueeze(0)) for p_, y_ in zip(y_pred, y)], dim=0) - else: - score = _compute(y_pred, y) + score = self.metric_fn(y_pred, y) self._scores.append(score.to(self._device)) def compute(self) -> Any: @@ -117,6 +105,7 @@ def compute(self) -> Any: if idist.get_rank() == 0: # run compute_fn on zero rank only result = self._reduce(_scores) + result = result[0] if isinstance(result, (list, tuple)) else result if ws > 1: # broadcast result to all processes @@ -125,7 +114,7 @@ def compute(self) -> Any: return result.item() if isinstance(result, torch.Tensor) else result def _reduce(self, scores) -> Any: - return do_metric_reduction(scores, MetricReduction.MEAN)[0] + return self.metric_fn.aggregate(scores) def attach(self, engine: Engine, name: str) -> None: """ diff --git a/monai/handlers/mean_dice.py b/monai/handlers/mean_dice.py index 6d51c534cf..f11f0b729f 100644 --- a/monai/handlers/mean_dice.py +++ b/monai/handlers/mean_dice.py @@ -29,6 +29,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -39,14 +40,14 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: mean dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. See also: :py:meth:`monai.metrics.meandice.compute_meandice` """ - metric_fn = DiceMetric( - include_background=include_background, - reduction=MetricReduction.NONE, - ) + metric_fn = DiceMetric(include_background=include_background, reduction=reduction) super().__init__( metric_fn=metric_fn, output_transform=output_transform, diff --git a/monai/handlers/regression_metrics.py b/monai/handlers/regression_metrics.py index 2320203ff6..3129f2eb59 100644 --- a/monai/handlers/regression_metrics.py +++ b/monai/handlers/regression_metrics.py @@ -28,6 +28,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -36,13 +37,14 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. See also: :py:class:`monai.metrics.MSEMetric` """ - metric_fn = MSEMetric( - reduction=MetricReduction.NONE, - ) + metric_fn = MSEMetric(reduction=reduction) super().__init__( metric_fn=metric_fn, output_transform=output_transform, @@ -61,6 +63,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -69,13 +72,14 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: mean absolute error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. See also: :py:class:`monai.metrics.MAEMetric` """ - metric_fn = MAEMetric( - reduction=MetricReduction.NONE, - ) + metric_fn = MAEMetric(reduction=reduction) super().__init__( metric_fn=metric_fn, output_transform=output_transform, @@ -94,6 +98,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -102,13 +107,14 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: root mean squared error of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. See also: :py:class:`monai.metrics.RMSEMetric` """ - metric_fn = RMSEMetric( - reduction=MetricReduction.NONE, - ) + metric_fn = RMSEMetric(reduction=reduction) super().__init__( metric_fn=metric_fn, output_transform=output_transform, @@ -128,6 +134,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -138,14 +145,14 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: PSNR of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. See also: :py:class:`monai.metrics.PSNRMetric` """ - metric_fn = PSNRMetric( - max_val=max_val, - reduction=MetricReduction.NONE, - ) + metric_fn = PSNRMetric(max_val=max_val, reduction=reduction) super().__init__( metric_fn=metric_fn, output_transform=output_transform, diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 8011dab8db..1b12dc0e96 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -9,16 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Union +from typing import Any, Callable, Tuple, Union import torch from monai.handlers.utils import evenly_divisible_all_gather -from monai.metrics import compute_roc_auc +from monai.metrics import ROCAUCMetric from monai.utils import Average, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") EpochMetric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "EpochMetric") +reinit__is_reduced, _ = optional_import("ignite.metrics.metric", "0.4.4", exact_version, "reinit__is_reduced") class ROCAUC(EpochMetric): # type: ignore[valid-type, misc] # due to optional_import @@ -56,21 +57,19 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", ) -> None: - def _compute_fn(pred, label): - return compute_roc_auc( - y_pred=pred, - y=label, - average=Average(average), - ) - + self.metric = ROCAUCMetric(average=Average(average)) self._is_reduced: bool = False super().__init__( - compute_fn=_compute_fn, + compute_fn=lambda p, y: self.metric.aggregate(data=(p, y)), output_transform=output_transform, check_compute_fn=False, device=device, ) + @reinit__is_reduced + def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: + super().update(output=self.metric(output[0], output[1])) + def compute(self) -> Any: _prediction_tensor = torch.cat(self._predictions, dim=0) _target_tensor = torch.cat(self._targets, dim=0) diff --git a/monai/handlers/surface_distance.py b/monai/handlers/surface_distance.py index 7c2322354a..730507adc8 100644 --- a/monai/handlers/surface_distance.py +++ b/monai/handlers/surface_distance.py @@ -31,6 +31,7 @@ def __init__( output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", save_details: bool = True, + reduction: Union[MetricReduction, str] = MetricReduction.MEAN, ) -> None: """ @@ -45,13 +46,16 @@ def __init__( device: device specification in case of distributed computation usage. save_details: whether to save metric computation details per image, for example: surface dice of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key. + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. """ metric_fn = SurfaceDistanceMetric( include_background=include_background, symmetric=symmetric, distance_metric=distance_metric, - reduction=MetricReduction.NONE, + reduction=reduction, ) super().__init__( metric_fn=metric_fn, diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 3113090c62..51d0d03b0c 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -13,7 +13,8 @@ from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance from .meandice import DiceMetric, compute_meandice +from .metric import Metric from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric -from .rocauc import compute_roc_auc +from .rocauc import ROCAUCMetric, compute_roc_auc from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py index 9c15b320eb..554aeef146 100644 --- a/monai/metrics/confusion_matrix.py +++ b/monai/metrics/confusion_matrix.py @@ -10,15 +10,17 @@ # limitations under the License. import warnings -from typing import Sequence, Union +from typing import Optional, Sequence, Union import torch from monai.metrics.utils import do_metric_reduction, ignore_background -from monai.utils import MetricReduction +from monai.utils import MetricReduction, ensure_tuple +from .metric import Metric -class ConfusionMatrixMetric: + +class ConfusionMatrixMetric(Metric): """ Compute confusion matrix related metrics. This function supports to calculate all metrics mentioned in: `Confusion matrix `_. @@ -43,14 +45,10 @@ class ConfusionMatrixMetric: Except for input only one metric, multiple metrics are also supported via input a sequence of metric names, such as ("sensitivity", "precision", "recall"), if ``compute_sample`` is ``True``, multiple ``f`` and ``not_nans`` will be returned with the same order as input names when calling the class. - compute_sample: if ``True``, each sample's metric will be computed first. If ``False``, the confusion matrix for each image - (the output of function ``get_confusion_matrix``) will be returned. In this way, users should achieve the confusion - matrixes for all images during an epoch and then use ``compute_confusion_matrix_metric`` to calculate the metric. - Defaults to ``False``. + compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first. + if ``False``, compute reduction on the confusion matrices first, defaults to ``False``. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Reduction will only be employed when - ``compute_sample`` is ``True``. Defaults to ``"mean"``. """ @@ -63,11 +61,11 @@ def __init__( ) -> None: super().__init__() self.include_background = include_background - self.metric_name = metric_name + self.metric_name = ensure_tuple(metric_name) self.compute_sample = compute_sample self.reduction = reduction - def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): """ Args: y_pred: input data to compute. It must be one-hot format and first dim is batch. @@ -78,9 +76,11 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): ValueError: when `y` is not a binarized tensor. ValueError: when `y_pred` has less than two dimensions. """ + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") # check binarized input if not torch.all(y_pred.byte() == y_pred): - warnings.warn("y_pred is not a binarized tensor here!") + warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): raise ValueError("y should be a binarized tensor.") # check dimension @@ -92,27 +92,31 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): warnings.warn("As for classification task, compute_sample should be False.") self.compute_sample = False - confusion_matrix = get_confusion_matrix( + return get_confusion_matrix( y_pred=y_pred, y=y, include_background=self.include_background, ) - if self.compute_sample: - if isinstance(self.metric_name, str): - confusion_matrix = compute_confusion_matrix_metric(self.metric_name, confusion_matrix) - f, not_nans = do_metric_reduction(confusion_matrix, self.reduction) - return f, not_nans - if len(self.metric_name) < 1: - raise ValueError("the sequence should at least has on metric name.") - results = [] - for metric_name in self.metric_name: - sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, confusion_matrix) + def aggregate(self, data: torch.Tensor): + """ + Execute reduction for the confusion matrix values, the `data` usually is a Tensor of shape [BC4], + Where, the third dimension represents the number of true positive, false positive, true negative + and false negative values for each channel of each sample within the input batch. Where, B equals + to the batch size and C equals to the number of classes that need to be computed. + + """ + results = [] + for metric_name in self.metric_name: + if self.compute_sample: + sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, data) f, not_nans = do_metric_reduction(sub_confusion_matrix, self.reduction) - results.append(f) - results.append(not_nans) - return results - return confusion_matrix + else: + f, not_nans = do_metric_reduction(data, self.reduction) + f = compute_confusion_matrix_metric(metric_name, f) + results.append(f) + results.append(not_nans) + return results def get_confusion_matrix( diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py index 9617c0365a..69f9189cc8 100644 --- a/monai/metrics/hausdorff_distance.py +++ b/monai/metrics/hausdorff_distance.py @@ -18,18 +18,19 @@ from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction +from .metric import Metric + __all__ = ["HausdorffDistanceMetric", "compute_hausdorff_distance", "compute_percent_hausdorff_distance"] -class HausdorffDistanceMetric: +class HausdorffDistanceMetric(Metric): """ Compute Hausdorff Distance between two tensors. It can support both multi-classes and multi-labels tasks. It supports both directed and non-directed Hausdorff distance calculation. In addition, specify the `percentile` - parameter can get the percentile of the distance. - Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + parameter can get the percentile of the distance. Input `y_pred` is compared with ground truth `y`. `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. - + `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]). The implementation refers to `DeepMind's implementation `_. Args: @@ -43,7 +44,7 @@ class HausdorffDistanceMetric: directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + Define the mode to reduce computation result. Defaults to ``"mean"``. """ @@ -62,7 +63,7 @@ def __init__( self.directed = directed self.reduction = reduction - def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): """ Args: y_pred: input data to compute, typical segmentation model output. @@ -75,15 +76,17 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): ValueError: when `y` is not a binarized tensor. ValueError: when `y_pred` has less than three dimensions. """ + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") if not torch.all(y_pred.byte() == y_pred): - warnings.warn("y_pred is not a binarized tensor here!") + warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): raise ValueError("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute (BxC) for each channel for each batch - f = compute_hausdorff_distance( + return compute_hausdorff_distance( y_pred=y_pred, y=y, include_background=self.include_background, @@ -92,8 +95,13 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): directed=self.directed, ) + def aggregate(self, data: torch.Tensor): + """ + Execute reduction logic for the output of `compute_hausdorff_distance`. + + """ # do metric reduction - f, not_nans = do_metric_reduction(f, self.reduction) + f, not_nans = do_metric_reduction(data, self.reduction) return f, not_nans diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 9d27fff56f..bdb93f4794 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -10,31 +10,34 @@ # limitations under the License. import warnings -from typing import Union +from typing import Optional, Union import torch from monai.metrics.utils import do_metric_reduction, ignore_background from monai.utils import MetricReduction +from .metric import Metric -class DiceMetric: + +class DiceMetric(Metric): """ Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. - Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + Input `y_pred` is compared with ground truth `y`. `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. The `include_background` parameter can be set to ``False`` for an instance of DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. If the non-background segmentations are small compared to the total image size they can get overwhelmed by the signal from the background so excluding it in such cases helps convergence. + `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]). Args: include_background: whether to skip Dice computation on the first channel of the predicted output. Defaults to ``True``. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + Define the mode to reduce computation result. Defaults to ``"mean"``. """ @@ -47,7 +50,7 @@ def __init__( self.include_background = include_background self.reduction = reduction - def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): """ Args: y_pred: input data to compute, typical segmentation model output. @@ -60,22 +63,29 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): ValueError: when `y` is not a binarized tensor. ValueError: when `y_pred` has less than three dimensions. """ + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") if not torch.all(y_pred.byte() == y_pred): - warnings.warn("y_pred is not a binarized tensor here!") + warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): raise ValueError("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute dice (BxC) for each channel for each batch - f = compute_meandice( + return compute_meandice( y_pred=y_pred, y=y, include_background=self.include_background, ) + def aggregate(self, data: torch.Tensor): + """ + Execute reduction logic for the output of `compute_meandice`. + + """ # do metric reduction - f, not_nans = do_metric_reduction(f, self.reduction) + f, not_nans = do_metric_reduction(data, self.reduction) return f, not_nans diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py new file mode 100644 index 0000000000..00cd9f7b73 --- /dev/null +++ b/monai/metrics/metric.py @@ -0,0 +1,91 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +import torch + +from monai.config import TensorOrList + + +class Metric(ABC): + """ + Base class of Metrics interface. + `__call__` is supposed to compute independent logic for several samples of `y_pred` and `y`(optional). + Ususally, subclass only needs to implement the `_apply` function for computation process. + And `reduce` is supposed to execute reduction for the final result, it can be used for 1 batch data + or for the accumulated overall data. + + """ + + def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): + """ + Execute basic computation for model prediction and ground truth. + It can support both `list of channel-first Tensor` and `batch-first Tensor`. + And users can execute on every batch of data, then accumulate the results, or + accumulate the original `y_pred` and `y`, then execute on the accumulated data. + + Args: + y_pred: the model prediction data to compute, must be a list of `channel-first` Tensor + or a `batch-first` Tensor. + y: the ground truth to compute, must be a list of `channel-first` Tensor + or a `batch-first` Tensor. + + """ + ret: TensorOrList + if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)): + # if y_pred or y is a list of channel-first data, add batch dim and compute metric + ret_: List[torch.Tensor] = self._compute_list(y_pred, y) + # concat the list of results + if isinstance(ret_[0], torch.Tensor): + ret = torch.cat(ret_, dim=0) + elif isinstance(ret_[0], (list, tuple)) and all([isinstance(i, torch.Tensor) for i in ret_[0]]): + # if _compute() returned not only 1 Tensor, concat them separately + ret = [torch.cat([k[i] for k in ret_], dim=0) for i in range(len(ret_[0]))] + else: + # if not expected data type, return raw results directly + ret = ret_ + elif isinstance(y_pred, torch.Tensor): + y_ = y.detach() if y is not None and isinstance(y, torch.Tensor) else None + ret = self._compute(y_pred.detach(), y_) + else: + raise ValueError("y_pred or y must be a list of `channel-first` Tensors or a `batch-first` Tensor.") + + return ret + + def _compute_list(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): + """ + Excute the computation for every item of a list. + Subclass may enhance the operation with multi-threads to accelerate. + + """ + if y is not None: + return [self._compute(p_.detach().unsqueeze(0), y_.detach().unsqueeze(0)) for p_, y_ in zip(y_pred, y)] + else: + return [self._compute(p_.detach().unsqueeze(0), None) for p_ in y_pred] + + @abstractmethod + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): + """ + Actual computation logic of the metric, input data should be `batch-first` Tensor. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def aggregate(self, data: Any): + """ + Aggregate the metric results. Users can call it for the batch data of every iteration + or accumulte the results of every iteration and call it for the final output. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 78c256f9e8..4ea32cb276 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -10,23 +10,38 @@ # limitations under the License. import math -from abc import ABC, abstractmethod +from abc import abstractmethod from functools import partial -from typing import Any, Union +from typing import Any, Optional, Union import torch from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction +from .metric import Metric + + +class RegressionMetric(Metric): + """ + Base class for regression metrics. + Input `y_pred` is compared with ground truth `y`. + Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. + `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]). + + Args: + reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``} + Define the mode to reduce computation result. Defaults to ``"mean"``. + + """ -class RegressionMetric(ABC): def __init__(self, reduction: Union[MetricReduction, str] = MetricReduction.MEAN) -> None: super().__init__() self.reduction = reduction - def _reduce(self, f: torch.Tensor): - return do_metric_reduction(f, self.reduction) + def aggregate(self, data: torch.Tensor): + return do_metric_reduction(data, self.reduction) def _check_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: if y_pred.shape != y.shape: @@ -42,11 +57,11 @@ def _check_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") self._check_shape(y_pred, y) - out = self._compute_metric(y_pred, y) - y, not_nans = self._reduce(out) - return y, not_nans + return self._compute_metric(y_pred, y) class MSEMetric(RegressionMetric): @@ -57,7 +72,7 @@ class MSEMetric(RegressionMetric): More info: https://en.wikipedia.org/wiki/Mean_squared_error - Input `y_pred` (BCHW[D] where C is number of channels) is compared with ground truth `y` (BCHW[D]). + Input `y_pred` is compared with ground truth `y`. Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: @@ -87,7 +102,7 @@ class MAEMetric(RegressionMetric): More info: https://en.wikipedia.org/wiki/Mean_absolute_error - Input `y_pred` (BCHW[D] where C is number of channels) is compared with ground truth `y` (BCHW[D]). + Input `y_pred` is compared with ground truth `y`. Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: @@ -118,7 +133,7 @@ class RMSEMetric(RegressionMetric): More info: https://en.wikipedia.org/wiki/Root-mean-square_deviation - Input `y_pred` (BCHW[D] where C is number of channels) is compared with ground truth `y` (BCHW[D]). + Input `y_pred` is compared with ground truth `y`. Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: @@ -153,7 +168,7 @@ class PSNRMetric(RegressionMetric): Help taken from: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/image_ops_impl.py line 4139 - Input `y_pred` (BCHW[D] where C is number of channels) is compared with ground truth `y` (BCHW[D]). + Input `y_pred` is compared with ground truth `y`. Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. Args: diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index a6d70b6dd8..2f3e01a821 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -9,15 +9,71 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, cast +from typing import Optional, Tuple, Union, cast import numpy as np import torch from monai.utils import Average +from .metric import Metric -def _calculate(y: torch.Tensor, y_pred: torch.Tensor) -> float: + +class ROCAUCMetric(Metric): + """ + Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: + `sklearn.metrics.roc_auc_score `_. + The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor. + + Args: + average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} + Type of averaging performed if not binary classification. + Defaults to ``"macro"``. + + - ``"macro"``: calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + - ``"weighted"``: calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - ``"micro"``: calculate metrics globally by considering each element of the label + indicator matrix as a label. + - ``"none"``: the scores for each class are returned. + + """ + + def __init__(self, average: Union[Average, str] = Average.MACRO) -> None: + super().__init__() + self.average = average + + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): + return y_pred, y + + def aggregate(self, data: Tuple[torch.Tensor, torch.Tensor]): + """ + As AUC metric needs to execute on the overall data, so usually users accumulate `y_pred` and `y` + of every iteration, then execute real computation and reduction on the accumulated data. + For example:: + + y_pred = [] + y = [] + metric = ROCAUCMetric(average=Average.MACRO) + + for batch in dataloader: + image, label = batch + pred = model(image) + pred_, y_ = metric(pred, label) + y.append(y_) + y_pred.append(pred_) + + result = metric.aggregate(torch.cat(y_pred, dim=0), torch.cat(y, dim=0)) + + """ + y_pred, y = data + # compute final value and do metric reduction + return compute_roc_auc(y_pred=y_pred, y=y, average=self.average) + + +def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float: if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)): raise AssertionError("y and y_pred must be 1 dimension data with same length.") if not y.unique().equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)): @@ -96,16 +152,16 @@ def compute_roc_auc( y = y.squeeze(dim=-1) if y_pred_ndim == 1: - return _calculate(y, y_pred) + return _calculate(y_pred, y) if y.shape != y_pred.shape: raise AssertionError("data shapes of y_pred and y do not match.") average = Average(average) if average == Average.MICRO: - return _calculate(y.flatten(), y_pred.flatten()) + return _calculate(y_pred.flatten(), y.flatten()) y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1) - auc_values = [_calculate(y_, y_pred_) for y_, y_pred_ in zip(y, y_pred)] + auc_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)] if average == Average.NONE: return auc_values if average == Average.MACRO: diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py index d4b2a84572..7b4519e04a 100644 --- a/monai/metrics/surface_distance.py +++ b/monai/metrics/surface_distance.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Union +from typing import Optional, Union import numpy as np import torch @@ -18,14 +18,17 @@ from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background from monai.utils import MetricReduction +from .metric import Metric -class SurfaceDistanceMetric: + +class SurfaceDistanceMetric(Metric): """ Compute Surface Distance between two tensors. It can support both multi-classes and multi-labels tasks. It supports both symmetric and asymmetric surface distance calculation. - Input `y_pred` (BNHW[D] where N is number of classes) is compared with ground truth `y` (BNHW[D]). + Input `y_pred` is compared with ground truth `y`. `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values. + `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]). Args: include_background: whether to skip distance computation on the first channel of @@ -36,7 +39,7 @@ class SurfaceDistanceMetric: the metric used to compute surface distance. Defaults to ``"euclidean"``. reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``} - Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``. + Define the mode to reduce computation result. Defaults to ``"mean"``. """ @@ -53,7 +56,7 @@ def __init__( self.symmetric = symmetric self.reduction = reduction - def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): + def _compute(self, y_pred: torch.Tensor, y: Optional[torch.Tensor] = None): """ Args: y_pred: input data to compute, typical segmentation model output. @@ -66,15 +69,17 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): ValueError: when `y` is not a binarized tensor. ValueError: when `y_pred` has less than three dimensions. """ + if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor): + raise ValueError("y_pred and y must be PyTorch Tensor.") if not torch.all(y_pred.byte() == y_pred): - warnings.warn("y_pred is not a binarized tensor here!") + warnings.warn("y_pred should be a binarized tensor.") if not torch.all(y.byte() == y): raise ValueError("y should be a binarized tensor.") dims = y_pred.ndimension() if dims < 3: raise ValueError("y_pred should have at least three dimensions.") # compute (BxC) for each channel for each batch - f = compute_average_surface_distance( + return compute_average_surface_distance( y_pred=y_pred, y=y, include_background=self.include_background, @@ -82,8 +87,13 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): distance_metric=self.distance_metric, ) + def aggregate(self, data: torch.Tensor): + """ + Execute reduction logic for the output of `compute_average_surface_distance`. + + """ # do metric reduction - f, not_nans = do_metric_reduction(f, self.reduction) + f, not_nans = do_metric_reduction(data, self.reduction) return f, not_nans diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py index 56ca5371ab..7e9da9851b 100644 --- a/tests/test_compute_confusion_matrix.py +++ b/tests/test_compute_confusion_matrix.py @@ -16,7 +16,12 @@ import torch from parameterized import parameterized -from monai.metrics import ConfusionMatrixMetric, get_confusion_matrix +from monai.metrics import ( + ConfusionMatrixMetric, + compute_confusion_matrix_metric, + do_metric_reduction, + get_confusion_matrix, +) # input data data: Dict[Any, Any] = { @@ -59,6 +64,8 @@ "y": torch.tensor([[1, 0, 0], [0, 1, 0]]), "compute_sample": False, "include_background": True, + "metric_name": "tpr", + "reduction": "mean_channel", } # 1. test confusion matrix @@ -224,7 +231,7 @@ def test_compute_sample(self, input_data, expected_value): vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) - result, _ = metric(**vals) + result, _ = metric.aggregate(metric(**vals)) np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_MULTI_METRICS) @@ -234,7 +241,7 @@ def test_compute_sample_multiple_metrics(self, input_data, expected_values): vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) - results = metric(**vals) + results = metric.aggregate(metric(**vals)) for idx in range(0, len(results), 2): result = results[idx] expected_value = expected_values[int(idx / 2)] @@ -247,7 +254,7 @@ def test_compute_sample_with_nan(self, input_data, expected_value, expected_not_ vals["y_pred"] = params.pop("y_pred") vals["y"] = params.pop("y") metric = ConfusionMatrixMetric(**params) - result, not_nans = metric(**vals) + result, not_nans = metric.aggregate(metric(**vals)) np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(not_nans, expected_not_nans, atol=1e-4, rtol=1e-4) @@ -260,6 +267,10 @@ def test_clf_with_nan(self, input_data, expected_value): metric = ConfusionMatrixMetric(**params) result = metric(**vals) np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) + result, _ = metric.aggregate(result) + expected_value, _ = do_metric_reduction(expected_value, "mean_channel") + expected_value = compute_confusion_matrix_metric("tpr", expected_value) + np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py index 64f38dcdb8..c6763f59d5 100644 --- a/tests/test_compute_meandice.py +++ b/tests/test_compute_meandice.py @@ -167,6 +167,14 @@ [[1.0000, 1.0000], [1.0000, 1.0000]], ] +TEST_CASE_10 = [ + { + "y": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], + "y_pred": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], + }, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] + class TestComputeMeanDice(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9]) @@ -180,7 +188,7 @@ def test_nans(self, input_data, expected_value): self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value)) # DiceMetric class tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10]) def test_value_class(self, input_data, expected_value): # same test as for compute_meandice @@ -188,14 +196,16 @@ def test_value_class(self, input_data, expected_value): vals["y_pred"] = input_data.pop("y_pred") vals["y"] = input_data.pop("y") dice_metric = DiceMetric(**input_data, reduction="none") - result, _ = dice_metric(**vals) + result = dice_metric(**vals) + result, _ = dice_metric.aggregate(result) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]) def test_nans_class(self, params, input_data, expected_value): dice_metric = DiceMetric(**params) - result, _ = dice_metric(**input_data) + result = dice_metric(**input_data) + result, _ = dice_metric.aggregate(result) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index 20d37a1d70..2b9157444a 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -65,16 +65,20 @@ def test_shape_reduction(self): # iterate over regression metrics, check shape for diff. reduction func for mt_fn in metrics: - out_tensor, _ = mt_fn(reduction="mean")(in_tensor, in_tensor) + mt = mt_fn(reduction="mean") + out_tensor, _ = mt.aggregate(mt(in_tensor, in_tensor)) self.assertTrue(len(out_tensor.shape) == 1) - out_tensor, _ = mt_fn(reduction="sum")(in_tensor, in_tensor) + mt = mt_fn(reduction="sum") + out_tensor, _ = mt.aggregate(mt(in_tensor, in_tensor)) self.assertTrue(len(out_tensor.shape) == 0) - out_tensor, _ = mt_fn(reduction="mean_channel")(in_tensor, in_tensor) + mt = mt_fn(reduction="mean_channel") + out_tensor, _ = mt.aggregate(mt(in_tensor, in_tensor)) self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch) - out_tensor, _ = mt_fn(reduction="sum_channel")(in_tensor, in_tensor) + mt = mt_fn(reduction="sum_channel") + out_tensor, _ = mt.aggregate(mt(in_tensor, in_tensor)) self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch) def test_compare_numpy(self): @@ -101,7 +105,8 @@ def test_compare_numpy(self): # check metrics for mt_fn, mt_fn_np in zip(metrics, metrics_np): - out_tensor, _ = mt_fn(reduction="mean")(y_pred=in_tensor_a, y=in_tensor_b) + mt = mt_fn(reduction="mean") + out_tensor, _ = mt.aggregate(mt(y_pred=in_tensor_a, y=in_tensor_b)) out_np = mt_fn_np(y_pred=in_tensor_a.cpu().numpy(), y=in_tensor_b.cpu().numpy()) np.testing.assert_allclose(out_tensor.cpu().numpy(), out_np, atol=1e-4) @@ -118,14 +123,14 @@ def test_ill_shape(self): with self.assertRaises(ValueError): in_tensor = torch.rand((basedim,)).to(device) for mt_fn in metrics: - out_tensor, _ = mt_fn()(in_tensor, in_tensor) + out_tensor = mt_fn()(in_tensor, in_tensor) # different shape for pred/target with self.assertRaises(ValueError): in_tensor_a = torch.rand((basedim,)).to(device) in_tensor_b = torch.rand((basedim, basedim)).to(device) for mt_fn in metrics: - out_tensor, _ = mt_fn()(y_pred=in_tensor_a, y=in_tensor_b) + out_tensor = mt_fn()(y_pred=in_tensor_a, y=in_tensor_b) def test_same_input(self): set_determinism(seed=123) @@ -148,7 +153,8 @@ def test_same_input(self): # check metrics for mt_fn, rs in zip(metrics, results): - out_tensor, _ = mt_fn(reduction="mean")(in_tensor, in_tensor) + mt = mt_fn(reduction="mean") + out_tensor, _ = mt.aggregate(mt(in_tensor, in_tensor)) np.testing.assert_allclose(out_tensor.cpu(), rs, atol=1e-4) def test_diff_input(self): @@ -173,7 +179,8 @@ def test_diff_input(self): # check metrics for mt_fn, rs in zip(metrics, results): - out_tensor, _ = mt_fn(reduction="mean")(in_tensor_a, in_tensor_b) + mt = mt_fn(reduction="mean") + out_tensor, _ = mt.aggregate(mt(in_tensor_a, in_tensor_b)) np.testing.assert_allclose(out_tensor.cpu(), rs, atol=1e-4) diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 10141ce0a7..acfcc022bb 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -15,7 +15,7 @@ import torch from parameterized import parameterized -from monai.metrics import compute_roc_auc +from monai.metrics import ROCAUCMetric, compute_roc_auc from monai.transforms import Activations, AsDiscrete TEST_CASE_1 = [ @@ -90,6 +90,14 @@ def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): result = compute_roc_auc(y_pred=y_pred, y=y, average=average) np.testing.assert_allclose(expected_value, result, rtol=1e-5) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): + y_pred = Activations(softmax=softmax)(y_pred) + y = AsDiscrete(to_onehot=to_onehot, n_classes=2)(y) + metric = ROCAUCMetric(average=average) + result = metric.aggregate(metric(y_pred=y_pred, y=y)) + np.testing.assert_allclose(expected_value, result, rtol=1e-5) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py index 0524676763..0c6e36066b 100644 --- a/tests/test_handler_confusion_matrix.py +++ b/tests/test_handler_confusion_matrix.py @@ -58,9 +58,9 @@ class TestHandlerConfusionMatrix(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_compute(self, input_params, expected_avg): metric = ConfusionMatrix(**input_params) - - y_pred = torch.Tensor([[[0], [1]], [[1], [0]]]) - y = torch.Tensor([[[0], [1]], [[0], [1]]]) + # test input a list of channel-first tensor + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] + y = [torch.Tensor([[0], [1]]), torch.Tensor([[0], [1]])] metric.update([y_pred, y]) y_pred = torch.Tensor([[[0], [1]], [[1], [0]]]) diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py index c0d2e723ca..bbc36cc2b5 100644 --- a/tests/test_handler_hausdorff_distance.py +++ b/tests/test_handler_hausdorff_distance.py @@ -49,7 +49,8 @@ def create_spherical_seg_3d( sampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0) -sampler_sphere_gt = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0).unsqueeze(0) +# test input a list of channel-first tensor +sampler_sphere_gt = [torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0)] sampler_sphere_zeros = torch.zeros_like(sampler_sphere) TEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt] diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py index 6b4bea594e..57c8cf4722 100644 --- a/tests/test_handler_mean_dice.py +++ b/tests/test_handler_mean_dice.py @@ -34,8 +34,8 @@ def _val_func(engine, batch): engine = Engine(_val_func) dice_metric.attach(engine=engine, name="mean_dice") - - y_pred = torch.Tensor([[[0], [1]], [[1], [0]]]) + # test input a list of channel-first tensor + y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])] y = torch.Tensor([[[0], [1]], [[0], [1]]]) dice_metric.update([y_pred, y]) diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 04e4d3edb3..36bb499cba 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -34,6 +34,8 @@ def test_compute(self): y = torch.Tensor([[0], [1]]) y_pred = act(y_pred) y = to_onehot(y) + # test a list of channel-first tensors + y_pred, y = list(y_pred), list(y) auc_metric.update([y_pred, y]) auc = auc_metric.compute() diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py index fbd86edb03..82cdb50d90 100644 --- a/tests/test_handler_surface_distance.py +++ b/tests/test_handler_surface_distance.py @@ -49,7 +49,8 @@ def create_spherical_seg_3d( sampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0) -sampler_sphere_gt = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0).unsqueeze(0) +# test input a list of channel-first tensor +sampler_sphere_gt = [torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0)] sampler_sphere_zeros = torch.zeros_like(sampler_sphere) TEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt] diff --git a/tests/test_hausdorff_distance.py b/tests/test_hausdorff_distance.py index 465900c12a..384ae82f1f 100644 --- a/tests/test_hausdorff_distance.py +++ b/tests/test_hausdorff_distance.py @@ -131,7 +131,7 @@ def test_value(self, input_data, expected_value): batch, n_class = 2, 3 batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) - result, _ = hd_metric(batch_seg_1, batch_seg_2) + result, _ = hd_metric.aggregate(hd_metric(batch_seg_1, batch_seg_2)) expected_value_curr = expected_value[ct] np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 @@ -144,7 +144,7 @@ def test_nans(self, input_data): hd_metric = HausdorffDistanceMetric(include_background=False) batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) - result, not_nans = hd_metric(batch_seg_1, batch_seg_2) + result, not_nans = hd_metric.aggregate(hd_metric(batch_seg_1, batch_seg_2)) np.testing.assert_allclose(0, result, rtol=1e-7) np.testing.assert_allclose(0, not_nans, rtol=1e-7) diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index af97236eda..8b15d3fc56 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -149,7 +149,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0): val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model)) - value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels) + value, not_nans = dice_metric.aggregate(dice_metric(y_pred=val_outputs, y=val_labels)) metric_count += not_nans.item() metric_sum += value.item() * not_nans.item() metric = metric_sum / metric_count @@ -218,7 +218,7 @@ def run_inference_test(root_dir, device="cuda:0"): # define sliding window size and batch size for windows inference sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model)) - value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels) + value, not_nans = dice_metric.aggregate(dice_metric(y_pred=val_outputs, y=val_labels)) metric_count += not_nans.item() metric_sum += value.item() * not_nans.item() saver.save_batch(val_outputs, val_data["img_meta_dict"]) diff --git a/tests/test_surface_distance.py b/tests/test_surface_distance.py index db90c87938..a80c06d463 100644 --- a/tests/test_surface_distance.py +++ b/tests/test_surface_distance.py @@ -136,7 +136,7 @@ def test_value(self, input_data, expected_value): batch, n_class = 2, 3 batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1]) - result, _ = sur_metric(batch_seg_1, batch_seg_2) + result, _ = sur_metric.aggregate(sur_metric(batch_seg_1, batch_seg_2)) expected_value_curr = expected_value[ct] np.testing.assert_allclose(expected_value_curr, result, rtol=1e-7) ct += 1 @@ -147,9 +147,10 @@ def test_nans(self, input_data): seg_1 = torch.tensor(seg_1) seg_2 = torch.tensor(seg_2) sur_metric = SurfaceDistanceMetric(include_background=False) - batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0) - batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0) - result, not_nans = sur_metric(batch_seg_1, batch_seg_2) + # test list of channel-first Tensor + batch_seg_1 = [seg_1.unsqueeze(0)] + batch_seg_2 = [seg_2.unsqueeze(0)] + result, not_nans = sur_metric.aggregate(sur_metric(batch_seg_1, batch_seg_2)) np.testing.assert_allclose(0, result, rtol=1e-7) np.testing.assert_allclose(0, not_nans, rtol=1e-7)