diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 9a9af601f9..8011dab8db 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import torch @@ -27,10 +27,6 @@ class ROCAUC(EpochMetric): # type: ignore[valid-type, misc] # due to optional_ accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`. Args: - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. - other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``. - for example: `other_act = lambda x: torch.log_softmax(x)`. average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} Type of averaging performed if not binary classification. Defaults to ``"macro"``. @@ -56,9 +52,6 @@ class ROCAUC(EpochMetric): # type: ignore[valid-type, misc] # due to optional_ def __init__( self, - to_onehot_y: bool = False, - softmax: bool = False, - other_act: Optional[Callable] = None, average: Union[Average, str] = Average.MACRO, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = "cpu", @@ -67,9 +60,6 @@ def _compute_fn(pred, label): return compute_roc_auc( y_pred=pred, y=label, - to_onehot_y=to_onehot_y, - softmax=softmax, - other_act=other_act, average=Average(average), ) diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py index 80a6671dfa..a6d70b6dd8 100644 --- a/monai/metrics/rocauc.py +++ b/monai/metrics/rocauc.py @@ -9,13 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from typing import Callable, Optional, Union, cast +from typing import Union, cast import numpy as np import torch -from monai.networks import one_hot from monai.utils import Average @@ -53,9 +51,6 @@ def _calculate(y: torch.Tensor, y_pred: torch.Tensor) -> float: def compute_roc_auc( y_pred: torch.Tensor, y: torch.Tensor, - to_onehot_y: bool = False, - softmax: bool = False, - other_act: Optional[Callable] = None, average: Union[Average, str] = Average.MACRO, ): """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to: @@ -67,10 +62,6 @@ def compute_roc_auc( it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2]. y: ground truth to compute ROC AUC metric, the first dim is batch. example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`). - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - softmax: whether to add softmax function to `y_pred` before computation. Defaults to False. - other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``. - for example: `other_act = lambda x: torch.log_softmax(x)`. average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``} Type of averaging performed if not binary classification. Defaults to ``"macro"``. @@ -86,8 +77,6 @@ def compute_roc_auc( Raises: ValueError: When ``y_pred`` dimension is not one of [1, 2]. ValueError: When ``y`` dimension is not one of [1, 2]. - ValueError: When ``softmax=True`` and ``other_act is not None``. Incompatible values. - TypeError: When ``other_act`` is not an ``Optional[Callable]``. ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"]. Note: @@ -107,22 +96,7 @@ def compute_roc_auc( y = y.squeeze(dim=-1) if y_pred_ndim == 1: - if to_onehot_y: - warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.") - if softmax: - warnings.warn("y_pred has only one channel, softmax=True ignored.") return _calculate(y, y_pred) - n_classes = y_pred.shape[1] - if to_onehot_y: - y = one_hot(y, n_classes) - if softmax and other_act is not None: - raise ValueError("Incompatible values: softmax=True and other_act is not None.") - if softmax: - y_pred = y_pred.float().softmax(dim=1) - if other_act is not None: - if not callable(other_act): - raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") - y_pred = other_act(y_pred) if y.shape != y_pred.shape: raise AssertionError("data shapes of y_pred and y do not match.") diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8b4f71093b..6462753cf9 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -86,9 +86,14 @@ def __call__( if other is not None and not callable(other): raise TypeError(f"other must be None or callable but is {type(other).__name__}.") + # convert to float as activation must operate on float tensor + img = img.float() if sigmoid or self.sigmoid: img = torch.sigmoid(img) if softmax or self.softmax: + # add channel dim if not existing + if img.ndimension() == 1: + img = img.unsqueeze(-1) img = torch.softmax(img, dim=1) act_func = self.other if other is None else other diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index 612bd375ac..10141ce0a7 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -16,71 +16,78 @@ from parameterized import parameterized from monai.metrics import compute_roc_auc +from monai.transforms import Activations, AsDiscrete TEST_CASE_1 = [ - { - "y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - "y": torch.tensor([[0], [1], [0], [1]]), - "to_onehot_y": True, - "softmax": True, - }, + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [1], [0], [1]]), + True, + True, + "macro", 0.75, ] -TEST_CASE_2 = [{"y_pred": torch.tensor([[0.5], [0.5], [0.2], [8.3]]), "y": torch.tensor([[0], [1], [0], [1]])}, 0.875] +TEST_CASE_2 = [ + torch.tensor([[0.5], [0.5], [0.2], [8.3]]), + torch.tensor([[0], [1], [0], [1]]), + False, + False, + "macro", + 0.875, +] -TEST_CASE_3 = [{"y_pred": torch.tensor([[0.5], [0.5], [0.2], [8.3]]), "y": torch.tensor([0, 1, 0, 1])}, 0.875] +TEST_CASE_3 = [ + torch.tensor([[0.5], [0.5], [0.2], [8.3]]), + torch.tensor([0, 1, 0, 1]), + False, + False, + "macro", + 0.875, +] -TEST_CASE_4 = [{"y_pred": torch.tensor([0.5, 0.5, 0.2, 8.3]), "y": torch.tensor([0, 1, 0, 1])}, 0.875] +TEST_CASE_4 = [ + torch.tensor([0.5, 0.5, 0.2, 8.3]), + torch.tensor([0, 1, 0, 1]), + False, + False, + "macro", + 0.875, +] TEST_CASE_5 = [ - { - "y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - "y": torch.tensor([[0], [1], [0], [1]]), - "to_onehot_y": True, - "softmax": True, - "average": "none", - }, + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), + torch.tensor([[0], [1], [0], [1]]), + True, + True, + "none", [0.75, 0.75], ] TEST_CASE_6 = [ - { - "y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), - "y": torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), - "softmax": True, - "average": "weighted", - }, + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), + torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), + True, + False, + "weighted", 0.56667, ] TEST_CASE_7 = [ - { - "y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), - "y": torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), - "softmax": True, - "average": "micro", - }, + torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]), + torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]), + True, + False, + "micro", 0.62, ] -TEST_CASE_8 = [ - { - "y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]), - "y": torch.tensor([[0], [1], [0], [1]]), - "to_onehot_y": True, - "other_act": lambda x: torch.log_softmax(x, dim=1), - }, - 0.75, -] - class TestComputeROCAUC(unittest.TestCase): - @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] - ) - def test_value(self, input_data, expected_value): - result = compute_roc_auc(**input_data) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): + y_pred = Activations(softmax=softmax)(y_pred) + y = AsDiscrete(to_onehot=to_onehot, n_classes=2)(y) + result = compute_roc_auc(y_pred=y_pred, y=y, average=average) np.testing.assert_allclose(expected_value, result, rtol=1e-5) diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 05f6eebce6..04e4d3edb3 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -15,18 +15,25 @@ import torch from monai.handlers import ROCAUC +from monai.transforms import Activations, AsDiscrete class TestHandlerROCAUC(unittest.TestCase): def test_compute(self): - auc_metric = ROCAUC(to_onehot_y=True, softmax=True) + auc_metric = ROCAUC() + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=True, n_classes=2) y_pred = torch.Tensor([[0.1, 0.9], [0.3, 1.4]]) y = torch.Tensor([[0], [1]]) + y_pred = act(y_pred) + y = to_onehot(y) auc_metric.update([y_pred, y]) y_pred = torch.Tensor([[0.2, 0.1], [0.1, 0.5]]) y = torch.Tensor([[0], [1]]) + y_pred = act(y_pred) + y = to_onehot(y) auc_metric.update([y_pred, y]) auc = auc_metric.compute() diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index c5cf44162c..e768906158 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -17,23 +17,29 @@ import torch.distributed as dist from monai.handlers import ROCAUC +from monai.transforms import Activations, AsDiscrete from tests.utils import DistCall, DistTestCase class DistributedROCAUC(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2, node_rank=0) def test_compute(self): - auc_metric = ROCAUC(to_onehot_y=True, softmax=True) + auc_metric = ROCAUC() + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=True, n_classes=2) + device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" if dist.get_rank() == 0: y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device) y = torch.tensor([[0], [1]], device=device) - auc_metric.update([y_pred, y]) if dist.get_rank() == 1: y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]], device=device) y = torch.tensor([[0], [1], [1]], device=device) - auc_metric.update([y_pred, y]) + + y_pred = act(y_pred) + y = to_onehot(y) + auc_metric.update([y_pred, y]) result = auc_metric.compute() np.testing.assert_allclose(0.66667, result, rtol=1e-4) diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 6f8c949d78..68493e4ffb 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -23,7 +23,18 @@ from monai.metrics import compute_roc_auc from monai.networks import eval_mode from monai.networks.nets import DenseNet121 -from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor +from monai.transforms import ( + Activations, + AddChannel, + AsDiscrete, + Compose, + LoadImage, + RandFlip, + RandRotate, + RandZoom, + ScaleIntensity, + ToTensor, +) from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value from tests.utils import DistTestCase, TimedCall, skip_if_quick @@ -63,6 +74,8 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", ) train_transforms.set_random_state(1234) val_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), ToTensor()]) + act = Activations(softmax=True) + to_onehot = AsDiscrete(to_onehot=True, n_classes=len(np.unique(train_y))) # create train, val data loaders train_ds = MedNISTDataset(train_x, train_y, train_transforms) @@ -110,10 +123,15 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", val_images, val_labels = val_data[0].to(device), val_data[1].to(device) y_pred = torch.cat([y_pred, model(val_images)], dim=0) y = torch.cat([y, val_labels], dim=0) - auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True) - metric_values.append(auc_metric) + + # compute accuracy acc_value = torch.eq(y_pred.argmax(dim=1), y) acc_metric = acc_value.sum().item() / len(acc_value) + # compute AUC + y_pred = act(y_pred) + y = to_onehot(y) + auc_metric = compute_roc_auc(y_pred, y) + metric_values.append(auc_metric) if auc_metric > best_metric: best_metric = auc_metric best_metric_epoch = epoch + 1