Describe the bug
monai.metrics.compute_roc_auc does not accept multi-class labels unless they are already one-hot encoded.
To Reproduce
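from monai.metrics import compute_roc_auc
import torch

y = torch.randint(4, size=(10, 1))  # class-index labels, shape [10, 1]
y_pred = torch.rand(10, 4)  # random scores for 4 classes (avoids all-zero rows and NaN after normalizing)
y_pred = y_pred / y_pred.sum(dim=1, keepdim=True)  # normalize each row to sum to 1
compute_roc_auc(y_pred, y)  # not accepted: y is not one-hot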
Expected behavior
The documentation states that a label of shape [16, 1] will be converted to [16, 2] (where 2 is inferred from y_pred). Either the function should convert class-index labels to one-hot format internally, or the documentation should state explicitly that one-hot labels are required.
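Until the conversion happens inside the function, one caller-side workaround is to one-hot encode the labels before calling `compute_roc_auc`. The sketch below is an assumption-laden illustration, not MONAI code: `to_one_hot_labels` is a hypothetical helper, it uses PyTorch's `torch.nn.functional.one_hot` rather than MONAI's own helper, and it assumes the labels are integer class indices in `[0, C)` with `C` taken from `y_pred.shape[1]`.

```python
import torch
import torch.nn.functional as F


def to_one_hot_labels(y: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: class-index labels [N] or [N, 1] -> float one-hot [N, num_classes]."""
    if y.ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=1)  # [N, 1] -> [N]
    if y.ndim != 1:
        raise ValueError(f"expected labels of shape [N] or [N, 1], got {tuple(y.shape)}")
    # F.one_hot needs int64 indices and returns an int64 tensor, so cast to float
    return F.one_hot(y.long(), num_classes=num_classes).float()


y = torch.randint(4, size=(10, 1))  # class indices, as in the reproduction
y_pred = torch.softmax(torch.rand(10, 4), dim=1)
y_onehot = to_one_hot_labels(y, num_classes=y_pred.shape[1])
print(y_onehot.shape)  # torch.Size([10, 4])
```

Passing `y_onehot` instead of `y` to `compute_roc_auc` should then satisfy the one-hot requirement for multi-class inputs.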
Additional context
Potential fix for converting the label internally:
Add the import

from monai.networks.utils import one_hot

and add

if y_pred_ndim == 2 and y.ndimension() == 1:
    y = one_hot(y, num_classes=y_pred.shape[1])

after these existing lines:

if y_pred_ndim == 2 and y_pred.shape[1] == 1:
    y_pred = y_pred.squeeze(dim=-1)
    y_pred_ndim = 1
if y_ndim == 2 and y.shape[1] == 1:
    y = y.squeeze(dim=-1)
Referenced code: MONAI/monai/metrics/rocauc.py, lines 137 to 141 in 1ecf5b6
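A self-contained sketch of the proposed patch: the existing squeeze logic from rocauc.py followed by the new one-hot step. The helper name `normalize_roc_auc_inputs` is hypothetical, and it swaps in `torch.nn.functional.one_hot` for `monai.networks.utils.one_hot` so it runs without MONAI; the real patch would use MONAI's helper as proposed. It mainly illustrates that the order matters: the one-hot step must run after `y` is squeezed, otherwise `[N, 1]` labels would still be 2-D and skip the conversion.

```python
import torch
import torch.nn.functional as F


def normalize_roc_auc_inputs(y_pred: torch.Tensor, y: torch.Tensor):
    """Hypothetical helper mirroring the shape handling quoted from rocauc.py,
    plus the proposed one-hot step (using torch's one_hot, not MONAI's)."""
    y_pred_ndim = y_pred.ndimension()
    y_ndim = y.ndimension()
    # existing logic: drop singleton channel dimensions
    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
        y_pred_ndim = 1
    if y_ndim == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)
    # proposed addition: runs after the squeeze, so [N, 1] labels are already [N]
    if y_pred_ndim == 2 and y.ndimension() == 1:
        y = F.one_hot(y.long(), num_classes=y_pred.shape[1]).to(y_pred.dtype)
    return y_pred, y


y = torch.randint(4, size=(10, 1))
y_pred = torch.softmax(torch.rand(10, 4), dim=1)
y_pred, y = normalize_roc_auc_inputs(y_pred, y)
print(y_pred.shape, y.shape)  # torch.Size([10, 4]) torch.Size([10, 4])
```

Binary inputs with shape [N, 1] for both tensors are only squeezed to [N], and labels that are already one-hot [N, C] pass through unchanged, so the existing behavior for those cases is preserved.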