Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,22 @@ def __init__(
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce

def ce(self, input: torch.Tensor, target: torch.Tensor):
"""
Compute CrossEntropy loss for the input and target.
Will remove the channel dim according to PyTorch CrossEntropyLoss:
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.

"""
n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
if n_pred_ch == n_target_ch:
# target is in the one-hot format, convert to BH[WD] format to calculate ce loss
target = torch.argmax(target, dim=1)
else:
target = torch.squeeze(target, dim=1)
target = target.long()
return self.cross_entropy(input, target)

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
Expand All @@ -683,16 +699,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
raise ValueError("the number of dimensions for input and target should be the same.")

dice_loss = self.dice(input, target)

n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
if n_pred_ch == n_target_ch:
# target is in the one-hot format, convert to BH[WD] format to calculate ce loss
target = torch.argmax(target, dim=1)
else:
target = torch.squeeze(target, dim=1)
target = target.long()
ce_loss = self.cross_entropy(input, target)
ce_loss = self.ce(input, target)
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

return total_loss


Expand Down Expand Up @@ -806,6 +815,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
dice_loss = self.dice(input, target)
focal_loss = self.focal(input, target)
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss

return total_loss


Expand Down
3 changes: 2 additions & 1 deletion tests/test_dice_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@
class TestDiceCELoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_result(self, input_param, input_data, expected_val):
result = DiceCELoss(**input_param)(**input_data)
diceceloss = DiceCELoss(**input_param)
result = diceceloss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)

def test_ill_shape(self):
Expand Down