From 46b8092d92494582dcd464fa54c957864d74d9d3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 15 Jul 2021 00:04:11 +0800 Subject: [PATCH 1/3] [DLMED] add get_loss_details Signed-off-by: Nic Ma --- monai/losses/dice.py | 34 +++++++++++++++++++++++++++++++++- tests/test_dice_ce_loss.py | 5 ++++- tests/test_dice_focal_loss.py | 4 ++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index d8db2bf586..57ec141df9 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, Dict, List, Optional, Sequence, Union import numpy as np import torch @@ -665,6 +665,11 @@ def __init__( raise ValueError("lambda_ce should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_ce = lambda_ce + self.loss_details: Dict[str, torch.Tensor] = { + "total_loss": torch.zeros(0), + "dice_loss": torch.zeros(0), + "ce_loss": torch.zeros(0), + } def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -681,6 +686,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) + self.loss_details["dice_loss"] = dice_loss.detach() n_pred_ch, n_target_ch = input.shape[1], target.shape[1] if n_pred_ch == n_target_ch: @@ -690,9 +696,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: target = torch.squeeze(target, dim=1) target = target.long() ce_loss = self.cross_entropy(input, target) + self.loss_details["ce_loss"] = ce_loss.detach() total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss + self.loss_details["total_loss"] = total_loss.detach() + return total_loss + def get_loss_details(self) -> Dict[str, torch.Tensor]: + """ + Get the raw values of DiceLoss, CELoss and the total loss. + + """ + return self.loss_details + class DiceFocalLoss(_Loss): """ @@ -783,6 +799,11 @@ def __init__( raise ValueError("lambda_focal should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_focal = lambda_focal + self.loss_details: Dict[str, torch.Tensor] = { + "total_loss": torch.zeros(0), + "dice_loss": torch.zeros(0), + "ce_loss": torch.zeros(0), + } def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -800,10 +821,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) + self.loss_details["dice_loss"] = dice_loss.detach() focal_loss = self.focal(input, target) + self.loss_details["focal_loss"] = focal_loss.detach() total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss + self.loss_details["total_loss"] = total_loss.detach() + return total_loss + def get_loss_details(self) -> Dict[str, torch.Tensor]: + """ + Get the raw values of DiceLoss, FocalLoss and the total loss. + + """ + return self.loss_details + dice = Dice = DiceLoss dice_ce = DiceCELoss diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index 3423e1425b..d3594fd6e8 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -71,8 +71,11 @@ class TestDiceCELoss(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_result(self, input_param, input_data, expected_val): - result = DiceCELoss(**input_param)(**input_data) + diceceloss = DiceCELoss(**input_param) + result = diceceloss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + details = diceceloss.get_loss_details() + np.testing.assert_allclose(details["total_loss"], expected_val, atol=1e-4, rtol=1e-4) def test_ill_shape(self): loss = DiceCELoss() diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 920994f8de..7b2d3137bf 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -39,6 +39,8 @@ def test_result_onehot_target_include_bg(self): result = dice_focal(pred, label) expected_val = dice(pred, label) + lambda_focal * focal(pred, label) np.testing.assert_allclose(result, expected_val) + details = dice_focal.get_loss_details() + np.testing.assert_allclose(details["total_loss"], expected_val) def test_result_no_onehot_no_bg(self): size = [3, 3, 5, 5] @@ -59,6 +61,8 @@ def test_result_no_onehot_no_bg(self): result = dice_focal(pred, label) expected_val = dice(pred, label) + lambda_focal * focal(pred, label) np.testing.assert_allclose(result, expected_val) + details = dice_focal.get_loss_details() + np.testing.assert_allclose(details["total_loss"], expected_val) def test_ill_shape(self): loss = DiceFocalLoss() From acef81f2240580c3e20a540f85389fb3f51dd0af Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 16 Jul 2021 08:32:31 +0800 Subject: [PATCH 2/3] [DLMED] enhance the doc-string Signed-off-by: Nic Ma --- monai/losses/dice.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 57ec141df9..9ea82c408e 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -705,6 +705,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: def get_loss_details(self) -> Dict[str, torch.Tensor]: """ Get the raw values of DiceLoss, CELoss and the total loss. + It's mainly used to visualize the loss values of `DiceLoss`, `CELoss` and `TotalLoss` to tune + the weights for them. + Note: printing a GPU tensor will potentially move it to CPU and may impact the training performance: + https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#avoid-unnecessary-cpu-gpu-synchronization. """ return self.loss_details @@ -832,6 +836,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: def get_loss_details(self) -> Dict[str, torch.Tensor]: """ Get the raw values of DiceLoss, FocalLoss and the total loss. + It's mainly used to visualize the loss values of `DiceLoss`, `FocalLoss` and `TotalLoss` to tune + the weights for them. + Note: printing a GPU tensor will potentially move it to CPU and may impact the training performance: + https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#avoid-unnecessary-cpu-gpu-synchronization. """ return self.loss_details From a272fdfbd9ff714bddfb92ff315c5b68258ab0e2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 16 Jul 2021 23:54:57 +0800 Subject: [PATCH 3/3] [DLMED] update DiceCELoss to provide API to compute CE Signed-off-by: Nic Ma --- monai/losses/dice.py | 66 ++++++++++------------------------- tests/test_dice_ce_loss.py | 2 -- tests/test_dice_focal_loss.py | 4 --- 3 files changed, 18 insertions(+), 54 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index e4f3150179..4af737064f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import Callable, Dict, List, Optional, Sequence, Union +from typing import Callable, List, Optional, Sequence, Union import numpy as np import torch @@ -667,11 +667,22 @@ def __init__( raise ValueError("lambda_ce should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_ce = lambda_ce - self.loss_details: Dict[str, torch.Tensor] = { - "total_loss": torch.zeros(0), - "dice_loss": torch.zeros(0), - "ce_loss": torch.zeros(0), - } + + def ce(self, input: torch.Tensor, target: torch.Tensor): + """ + Compute CrossEntropy loss for the input and target. + Will remove the channel dim according to PyTorch CrossEntropyLoss: + https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss. + + """ + n_pred_ch, n_target_ch = input.shape[1], target.shape[1] + if n_pred_ch == n_target_ch: + # target is in the one-hot format, convert to BH[WD] format to calculate ce loss + target = torch.argmax(target, dim=1) + else: + target = torch.squeeze(target, dim=1) + target = target.long() + return self.cross_entropy(input, target) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -688,33 +699,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) - self.loss_details["dice_loss"] = dice_loss.detach() - - n_pred_ch, n_target_ch = input.shape[1], target.shape[1] - if n_pred_ch == n_target_ch: - # target is in the one-hot format, convert to BH[WD] format to calculate ce loss - target = torch.argmax(target, dim=1) - else: - target = torch.squeeze(target, dim=1) - target = target.long() - ce_loss = self.cross_entropy(input, target) - self.loss_details["ce_loss"] = ce_loss.detach() + ce_loss = self.ce(input, target) total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss - self.loss_details["total_loss"] = total_loss.detach() return total_loss - def get_loss_details(self) -> Dict[str, torch.Tensor]: - """ - Get the raw values of DiceLoss, CELoss and the total loss. - It's mainly used to visualize the loss values of `DiceLoss`, `CELoss` and `TotalLoss` to tune - the weights for them. - Note: printing a GPU tensor will potentially move it to CPU and may impact the training performance: - https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#avoid-unnecessary-cpu-gpu-synchronization. - - """ - return self.loss_details - class DiceFocalLoss(_Loss): """ @@ -807,11 +796,6 @@ def __init__( raise ValueError("lambda_focal should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_focal = lambda_focal - self.loss_details: Dict[str, torch.Tensor] = { - "total_loss": torch.zeros(0), - "dice_loss": torch.zeros(0), - "ce_loss": torch.zeros(0), - } def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -829,25 +813,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) - self.loss_details["dice_loss"] = dice_loss.detach() focal_loss = self.focal(input, target) - self.loss_details["focal_loss"] = focal_loss.detach() total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss - self.loss_details["total_loss"] = total_loss.detach() return total_loss - def get_loss_details(self) -> Dict[str, torch.Tensor]: - """ - Get the raw values of DiceLoss, FocalLoss and the total loss. - It's mainly used to visualize the loss values of `DiceLoss`, `FocalLoss` and `TotalLoss` to tune - the weights for them. - Note: printing a GPU tensor will potentially move it to CPU and may impact the training performance: - https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#avoid-unnecessary-cpu-gpu-synchronization. - - """ - return self.loss_details - dice = Dice = DiceLoss dice_ce = DiceCELoss diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py index d3594fd6e8..66cfb36e99 100644 --- a/tests/test_dice_ce_loss.py +++ b/tests/test_dice_ce_loss.py @@ -74,8 +74,6 @@ def test_result(self, input_param, input_data, expected_val): diceceloss = DiceCELoss(**input_param) result = diceceloss(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) - details = diceceloss.get_loss_details() - np.testing.assert_allclose(details["total_loss"], expected_val, atol=1e-4, rtol=1e-4) def test_ill_shape(self): loss = DiceCELoss() diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py index 7b2d3137bf..920994f8de 100644 --- a/tests/test_dice_focal_loss.py +++ b/tests/test_dice_focal_loss.py @@ -39,8 +39,6 @@ def test_result_onehot_target_include_bg(self): result = dice_focal(pred, label) expected_val = dice(pred, label) + lambda_focal * focal(pred, label) np.testing.assert_allclose(result, expected_val) - details = dice_focal.get_loss_details() - np.testing.assert_allclose(details["total_loss"], expected_val) def test_result_no_onehot_no_bg(self): size = [3, 3, 5, 5] @@ -61,8 +59,6 @@ def test_result_no_onehot_no_bg(self): result = dice_focal(pred, label) expected_val = dice(pred, label) + lambda_focal * focal(pred, label) np.testing.assert_allclose(result, expected_val) - details = dice_focal.get_loss_details() - np.testing.assert_allclose(details["total_loss"], expected_val) def test_ill_shape(self): loss = DiceFocalLoss()