diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 3094b7a747..4af737064f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -668,6 +668,22 @@ def __init__( self.lambda_dice = lambda_dice self.lambda_ce = lambda_ce + def ce(self, input: torch.Tensor, target: torch.Tensor): + """ + Compute CrossEntropy loss for the input and target. + Will remove the channel dim according to PyTorch CrossEntropyLoss: + https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. + + """ + n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + if n_pred_ch == n_target_ch: + # target is in the one-hot format, convert to BH[WD] format to calculate ce loss + target = torch.argmax(target, dim=1) + else: + target = torch.squeeze(target, dim=1) + target = target.long() + return self.cross_entropy(input, target) + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: @@ -683,16 +699,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) - - n_pred_ch, n_target_ch = input.shape[1], target.shape[1] - if n_pred_ch == n_target_ch: - # target is in the one-hot format, convert to BH[WD] format to calculate ce loss - target = torch.argmax(target, dim=1) - else: - target = torch.squeeze(target, dim=1) - target = target.long() - ce_loss = self.cross_entropy(input, target) + ce_loss = self.ce(input, target) total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss + return total_loss @@ -806,6 +815,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: dice_loss = self.dice(input, target) focal_loss = self.focal(input, target) total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss + return total_loss diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 3423e1425b..66cfb36e99 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -71,7 +71,8 @@ class TestDiceCELoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): - result = DiceCELoss(**input_param)(**input_data) + diceceloss = DiceCELoss(**input_param) + result = diceceloss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) def test_ill_shape(self):