diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 5e0ccd3179..9c08658641 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -63,7 +63,7 @@ def __init__( pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) - fl = FocalLoss() + fl = FocalLoss(to_onehot_y=True) fl(pred, grnd) """