Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions monai/handlers/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union

import torch

Expand All @@ -27,10 +27,6 @@ class ROCAUC(EpochMetric): # type: ignore[valid-type, misc] # due to optional_
accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.

Args:
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
softmax: whether to add softmax function to `y_pred` before computation. Defaults to False.
other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``.
for example: `other_act = lambda x: torch.log_softmax(x)`.
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
Type of averaging performed if not binary classification. Defaults to ``"macro"``.

Expand All @@ -56,9 +52,6 @@ class ROCAUC(EpochMetric): # type: ignore[valid-type, misc] # due to optional_

def __init__(
self,
to_onehot_y: bool = False,
softmax: bool = False,
other_act: Optional[Callable] = None,
average: Union[Average, str] = Average.MACRO,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = "cpu",
Expand All @@ -67,9 +60,6 @@ def _compute_fn(pred, label):
return compute_roc_auc(
y_pred=pred,
y=label,
to_onehot_y=to_onehot_y,
softmax=softmax,
other_act=other_act,
average=Average(average),
)

Expand Down
28 changes: 1 addition & 27 deletions monai/metrics/rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Callable, Optional, Union, cast
from typing import Union, cast

import numpy as np
import torch

from monai.networks import one_hot
from monai.utils import Average


Expand Down Expand Up @@ -53,9 +51,6 @@ def _calculate(y: torch.Tensor, y_pred: torch.Tensor) -> float:
def compute_roc_auc(
y_pred: torch.Tensor,
y: torch.Tensor,
to_onehot_y: bool = False,
softmax: bool = False,
other_act: Optional[Callable] = None,
average: Union[Average, str] = Average.MACRO,
):
"""Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
Expand All @@ -67,10 +62,6 @@ def compute_roc_auc(
it must be One-Hot format and first dim is batch, example shape: [16] or [16, 2].
y: ground truth to compute ROC AUC metric, the first dim is batch.
example shape: [16, 1] will be converted into [16, 2] (where `2` is inferred from `y_pred`).
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
softmax: whether to add softmax function to `y_pred` before computation. Defaults to False.
other_act: callable function to replace `softmax` as activation layer if needed, Defaults to ``None``.
for example: `other_act = lambda x: torch.log_softmax(x)`.
average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
Type of averaging performed if not binary classification.
Defaults to ``"macro"``.
Expand All @@ -86,8 +77,6 @@ def compute_roc_auc(
Raises:
ValueError: When ``y_pred`` dimension is not one of [1, 2].
ValueError: When ``y`` dimension is not one of [1, 2].
ValueError: When ``softmax=True`` and ``other_act is not None``. Incompatible values.
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

Note:
Expand All @@ -107,22 +96,7 @@ def compute_roc_auc(
y = y.squeeze(dim=-1)

if y_pred_ndim == 1:
if to_onehot_y:
warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
if softmax:
warnings.warn("y_pred has only one channel, softmax=True ignored.")
return _calculate(y, y_pred)
n_classes = y_pred.shape[1]
if to_onehot_y:
y = one_hot(y, n_classes)
if softmax and other_act is not None:
raise ValueError("Incompatible values: softmax=True and other_act is not None.")
if softmax:
y_pred = y_pred.float().softmax(dim=1)
if other_act is not None:
if not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
y_pred = other_act(y_pred)

if y.shape != y_pred.shape:
raise AssertionError("data shapes of y_pred and y do not match.")
Expand Down
5 changes: 5 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def __call__(
if other is not None and not callable(other):
raise TypeError(f"other must be None or callable but is {type(other).__name__}.")

# convert to float as activation must operate on float tensor
img = img.float()
if sigmoid or self.sigmoid:
img = torch.sigmoid(img)
if softmax or self.softmax:
# add channel dim if not existing
if img.ndimension() == 1:
img = img.unsqueeze(-1)
img = torch.softmax(img, dim=1)

act_func = self.other if other is None else other
Expand Down
93 changes: 50 additions & 43 deletions tests/test_compute_roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,71 +16,78 @@
from parameterized import parameterized

from monai.metrics import compute_roc_auc
from monai.transforms import Activations, AsDiscrete

TEST_CASE_1 = [
{
"y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
"y": torch.tensor([[0], [1], [0], [1]]),
"to_onehot_y": True,
"softmax": True,
},
torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
torch.tensor([[0], [1], [0], [1]]),
True,
True,
"macro",
0.75,
]

TEST_CASE_2 = [{"y_pred": torch.tensor([[0.5], [0.5], [0.2], [8.3]]), "y": torch.tensor([[0], [1], [0], [1]])}, 0.875]
TEST_CASE_2 = [
torch.tensor([[0.5], [0.5], [0.2], [8.3]]),
torch.tensor([[0], [1], [0], [1]]),
False,
False,
"macro",
0.875,
]

TEST_CASE_3 = [{"y_pred": torch.tensor([[0.5], [0.5], [0.2], [8.3]]), "y": torch.tensor([0, 1, 0, 1])}, 0.875]
TEST_CASE_3 = [
torch.tensor([[0.5], [0.5], [0.2], [8.3]]),
torch.tensor([0, 1, 0, 1]),
False,
False,
"macro",
0.875,
]

TEST_CASE_4 = [{"y_pred": torch.tensor([0.5, 0.5, 0.2, 8.3]), "y": torch.tensor([0, 1, 0, 1])}, 0.875]
TEST_CASE_4 = [
torch.tensor([0.5, 0.5, 0.2, 8.3]),
torch.tensor([0, 1, 0, 1]),
False,
False,
"macro",
0.875,
]

TEST_CASE_5 = [
{
"y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
"y": torch.tensor([[0], [1], [0], [1]]),
"to_onehot_y": True,
"softmax": True,
"average": "none",
},
torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
torch.tensor([[0], [1], [0], [1]]),
True,
True,
"none",
[0.75, 0.75],
]

TEST_CASE_6 = [
{
"y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
"y": torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
"softmax": True,
"average": "weighted",
},
torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
True,
False,
"weighted",
0.56667,
]

TEST_CASE_7 = [
{
"y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
"y": torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
"softmax": True,
"average": "micro",
},
torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
True,
False,
"micro",
0.62,
]

TEST_CASE_8 = [
{
"y_pred": torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
"y": torch.tensor([[0], [1], [0], [1]]),
"to_onehot_y": True,
"other_act": lambda x: torch.log_softmax(x, dim=1),
},
0.75,
]


class TestComputeROCAUC(unittest.TestCase):
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
)
def test_value(self, input_data, expected_value):
result = compute_roc_auc(**input_data)
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
y_pred = Activations(softmax=softmax)(y_pred)
y = AsDiscrete(to_onehot=to_onehot, n_classes=2)(y)
result = compute_roc_auc(y_pred=y_pred, y=y, average=average)
np.testing.assert_allclose(expected_value, result, rtol=1e-5)


Expand Down
9 changes: 8 additions & 1 deletion tests/test_handler_rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@
import torch

from monai.handlers import ROCAUC
from monai.transforms import Activations, AsDiscrete


class TestHandlerROCAUC(unittest.TestCase):
def test_compute(self):
auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
auc_metric = ROCAUC()
act = Activations(softmax=True)
to_onehot = AsDiscrete(to_onehot=True, n_classes=2)

y_pred = torch.Tensor([[0.1, 0.9], [0.3, 1.4]])
y = torch.Tensor([[0], [1]])
y_pred = act(y_pred)
y = to_onehot(y)
auc_metric.update([y_pred, y])

y_pred = torch.Tensor([[0.2, 0.1], [0.1, 0.5]])
y = torch.Tensor([[0], [1]])
y_pred = act(y_pred)
y = to_onehot(y)
auc_metric.update([y_pred, y])

auc = auc_metric.compute()
Expand Down
12 changes: 9 additions & 3 deletions tests/test_handler_rocauc_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,29 @@
import torch.distributed as dist

from monai.handlers import ROCAUC
from monai.transforms import Activations, AsDiscrete
from tests.utils import DistCall, DistTestCase


class DistributedROCAUC(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
def test_compute(self):
auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
auc_metric = ROCAUC()
act = Activations(softmax=True)
to_onehot = AsDiscrete(to_onehot=True, n_classes=2)

device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
if dist.get_rank() == 0:
y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device)
y = torch.tensor([[0], [1]], device=device)
auc_metric.update([y_pred, y])

if dist.get_rank() == 1:
y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]], device=device)
y = torch.tensor([[0], [1], [1]], device=device)
auc_metric.update([y_pred, y])

y_pred = act(y_pred)
y = to_onehot(y)
auc_metric.update([y_pred, y])

result = auc_metric.compute()
np.testing.assert_allclose(0.66667, result, rtol=1e-4)
Expand Down
24 changes: 21 additions & 3 deletions tests/test_integration_classification_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@
from monai.metrics import compute_roc_auc
from monai.networks import eval_mode
from monai.networks.nets import DenseNet121
from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor
from monai.transforms import (
Activations,
AddChannel,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
ToTensor,
)
from monai.utils import set_determinism
from tests.testing_data.integration_answers import test_integration_value
from tests.utils import DistTestCase, TimedCall, skip_if_quick
Expand Down Expand Up @@ -63,6 +74,8 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0",
)
train_transforms.set_random_state(1234)
val_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), ToTensor()])
act = Activations(softmax=True)
to_onehot = AsDiscrete(to_onehot=True, n_classes=len(np.unique(train_y)))

# create train, val data loaders
train_ds = MedNISTDataset(train_x, train_y, train_transforms)
Expand Down Expand Up @@ -110,10 +123,15 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0",
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
y = torch.cat([y, val_labels], dim=0)
auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True)
metric_values.append(auc_metric)

# compute accuracy
acc_value = torch.eq(y_pred.argmax(dim=1), y)
acc_metric = acc_value.sum().item() / len(acc_value)
# compute AUC
y_pred = act(y_pred)
y = to_onehot(y)
auc_metric = compute_roc_auc(y_pred, y)
metric_values.append(auc_metric)
if auc_metric > best_metric:
best_metric = auc_metric
best_metric_epoch = epoch + 1
Expand Down