diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 6bf9680bca..ff9e336771 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -53,6 +53,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + pixelwise: bool = False, ) -> None: """ Args: @@ -99,6 +100,11 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + # pixelwise: bool = False, + self.pixelwise = pixelwise + if pixelwise: + if self.reduction != LossReduction.NONE.value: + raise ValueError('Can only compute pixelwise loss when reduction is "none"') def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -111,6 +117,50 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: have different shapes. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + Example: + >>> import torch + >>> from monai.losses.dice import DiceLoss + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = DiceLoss(reduction='none', pixelwise=True) + >>> loss = self(input, target) + >>> assert loss.shape == input.shape + + >>> # Original reduction=None behavior the spacetime dimensions + >>> # are always reduced + >>> self = DiceLoss(reduction='none', pixelwise=False, batch=False) + >>> loss = self(input, target) + >>> assert tuple(loss.shape) == (B, C, 1, 1) + >>> self = DiceLoss(reduction='none', pixelwise=False, batch=True) + >>> loss = self(input, target) + >>> assert tuple(loss.shape) == (1, C, 1, 1) + + >>> # Test that pixelwise variants of reduce=none correspond with a reduction mode + >>> r0 = DiceLoss(reduction='sum', batch=False)(input, target) + >>> r1 = DiceLoss(reduction='none', batch=False, pixelwise=True)(input, target).sum() + >>> r2 = DiceLoss(reduction='none', batch=False, pixelwise=False)(input, target).sum() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceLoss(reduction='sum', batch=True)(input, target) + >>> r1 = DiceLoss(reduction='none', batch=True, pixelwise=True)(input, target).sum() + >>> r2 = DiceLoss(reduction='none', batch=True, pixelwise=False)(input, target).sum() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceLoss(reduction='mean', batch=False)(input, target) + >>> r1 = DiceLoss(reduction='none', batch=False, pixelwise=True)(input, target).sum((2, 3)).mean() + >>> r2 = DiceLoss(reduction='none', batch=False, pixelwise=False)(input, target).mean() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceLoss(reduction='mean', batch=True)(input, target) + >>> r1 = DiceLoss(reduction='none', batch=True, pixelwise=True)(input, target).sum((0, 2, 3)).mean() + >>> r2 = DiceLoss(reduction='none', batch=True, pixelwise=False)(input, target).mean() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) """ if self.sigmoid: input = torch.sigmoid(input) @@ -148,28 +198,57 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, dim=reduce_axis) - if self.squared_pred: target = torch.pow(target, 2) input = torch.pow(input, 2) - ground_o = torch.sum(target, dim=reduce_axis) - pred_o = torch.sum(input, dim=reduce_axis) + ground_o = torch.sum(target, dim=reduce_axis, keepdim=True) + pred_o = torch.sum(input, dim=reduce_axis, keepdim=True) + + union = ground_o + pred_o + + if self.pixelwise and self.reduction == LossReduction.NONE.value: + intersection = target * input - denominator = ground_o + pred_o + if self.jaccard: + denominator = 2.0 * (union - intersection.sum(dim=reduce_axis, keepdim=True)) + else: + denominator = union - if self.jaccard: - denominator = 2.0 * (denominator - intersection) + if self.batch: + nitems = np.prod(intersection.shape[2:]) * intersection.shape[0] + else: + nitems = np.prod(intersection.shape[2:]) - f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr) + split_smooth_nr = self.smooth_nr / nitems + numer_split = 2.0 * intersection + split_smooth_nr + denom_split = denominator + self.smooth_dr - if self.reduction == LossReduction.MEAN.value: - f = torch.mean(f) # the batch and channel average - elif self.reduction == LossReduction.SUM.value: - f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + lead_split = 1 / nitems + f: torch.Tensor = lead_split - numer_split / denom_split + else: + intersection = torch.sum(target * input, dim=reduce_axis, keepdim=True) + + if self.jaccard: + denominator = 2.0 * (union - intersection) + else: + denominator = union + + numer = 2.0 * intersection + self.smooth_nr + denom = denominator + self.smooth_dr + f: torch.Tensor = 1.0 - numer / denom + + if self.reduction == LossReduction.MEAN.value: + f = torch.mean(f) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + pass + # f = torch.sum(f, dim=reduce_axis) + else: + raise ValueError( + f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].' + ) return f @@ -224,6 +303,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + pixelwise: bool = False, ) -> None: """ Args: @@ -272,6 +352,11 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch + self.pixelwise = pixelwise + if pixelwise: + if self.reduction != LossReduction.NONE.value: + raise ValueError('Can only compute pixelwise loss when reduction is "none"') + def w_func(self, grnd): if self.w_type == Weight.SIMPLE: return torch.reciprocal(grnd) @@ -288,6 +373,50 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + Example: + >>> import torch + >>> from monai.losses.dice import GeneralizedDiceLoss + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = GeneralizedDiceLoss(reduction='none', pixelwise=True) + >>> loss = self(input, target) + >>> assert loss.shape == input.shape + + >>> # Original reduction=None behavior the spacetime dimensions + >>> # are always reduced + >>> self = GeneralizedDiceLoss(reduction='none', pixelwise=False, batch=False) + >>> loss = self(input, target) + >>> assert tuple(loss.shape) == (B, 1, 1, 1) + >>> self = GeneralizedDiceLoss(reduction='none', pixelwise=False, batch=True) + >>> loss = self(input, target) + >>> assert tuple(loss.shape) == (1, C, 1, 1) + + >>> # Test that pixelwise variants of reduce=none correspond with a reduction mode + >>> r0 = GeneralizedDiceLoss(reduction='sum', batch=False)(input, target) + >>> r1 = GeneralizedDiceLoss(reduction='none', batch=False, pixelwise=True)(input, target).sum() + >>> r2 = GeneralizedDiceLoss(reduction='none', batch=False, pixelwise=False)(input, target).sum() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = GeneralizedDiceLoss(reduction='sum', batch=True)(input, target) + >>> r1 = GeneralizedDiceLoss(reduction='none', batch=True, pixelwise=True)(input, target).sum() + >>> r2 = GeneralizedDiceLoss(reduction='none', batch=True, pixelwise=False)(input, target).sum() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = GeneralizedDiceLoss(reduction='mean', batch=False)(input, target) + >>> r1 = GeneralizedDiceLoss(reduction='none', batch=False, pixelwise=True)(input, target).sum((1, 2, 3)).mean() + >>> r2 = GeneralizedDiceLoss(reduction='none', batch=False, pixelwise=False)(input, target).mean() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = GeneralizedDiceLoss(reduction='mean', batch=True)(input, target) + >>> r1 = GeneralizedDiceLoss(reduction='none', batch=True, pixelwise=True)(input, target).sum((0, 2, 3)).mean() + >>> r2 = GeneralizedDiceLoss(reduction='none', batch=True, pixelwise=False)(input, target).mean() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) """ if self.sigmoid: input = torch.sigmoid(input) @@ -322,29 +451,65 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, reduce_axis) - ground_o = torch.sum(target, reduce_axis) - pred_o = torch.sum(input, reduce_axis) + # The union will be part of the denominator + ground_o = torch.sum(target, reduce_axis, keepdim=True) + pred_o = torch.sum(input, reduce_axis, keepdim=True) + union = ground_o + pred_o - denominator = ground_o + pred_o - - w = self.w_func(ground_o.float()) + # Number of true voxels for each category in the truth + true_hist = torch.sum(target, reduce_axis, keepdim=True) + w = self.w_func(true_hist.float()) for b in w: infs = torch.isinf(b) b[infs] = 0.0 b[infs] = torch.max(b) - f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(0 if self.batch else 1) + self.smooth_nr) / ( - (denominator * w).sum(0 if self.batch else 1) + self.smooth_dr - ) + if self.pixelwise and self.reduction == LossReduction.NONE.value: + # The trick to reduce=none is to not reduce the numerator + # The computations are somewhat redundant and slower as compared + # to when reduce is mean or sum + intersection = target * input + + # Weight the numerator and denominator + w_intersection = intersection * w + w_union = (union * w).sum(0 if self.batch else 1, keepdim=True) - if self.reduction == LossReduction.MEAN.value: - f = torch.mean(f) # the batch and channel average - elif self.reduction == LossReduction.SUM.value: - f = torch.sum(f) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + # Split the numerator smooth term across voxels + split_smooth_nr = self.smooth_nr / w_intersection.numel() + + numer_split = 2.0 * w_intersection + split_smooth_nr + denom_split = w_union + self.smooth_dr + + if self.batch: + nitems = np.prod(numer_split.shape[2:]) * numer_split.shape[0] + else: + nitems = np.prod(numer_split.shape[1:]) + + lead_split = 1 / nitems + f: torch.Tensor = lead_split - numer_split / denom_split + else: + # When reduction is not None, we can be more efficient + intersection = torch.sum(target * input, reduce_axis, keepdim=True) + + w_intersection = (intersection * w).sum(0 if self.batch else 1, keepdim=True) + w_union = (union * w).sum(0 if self.batch else 1, keepdim=True) + + numer = 2.0 * w_intersection + self.smooth_nr + denom = w_union + self.smooth_dr + f: torch.Tensor = 1.0 - numer / denom + + if self.reduction == LossReduction.MEAN.value: + f = torch.mean(f) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + f = torch.sum(f) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + pass + # f = torch.sum(f, dim=reduce_axis) + else: + raise ValueError( + f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].' + ) return f @@ -373,6 +538,7 @@ def __init__( reduction: Union[LossReduction, str] = LossReduction.MEAN, smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, + pixelwise: bool = False, ) -> None: """ Args: @@ -437,49 +603,124 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) + self.pixelwise = pixelwise + if pixelwise: + if self.reduction != LossReduction.NONE.value: + raise ValueError('Can only compute pixelwise loss when reduction is "none"') + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. + Example: + >>> from monai.losses.dice import GeneralizedWassersteinDiceLoss + >>> import torch + >>> B, C, H, W = 7, 5, 3, 2 + >>> input = torch.rand(B, C, H, W) + + >>> dist_matrix = 1 - torch.eye(C, C) # this dist matrix reduces to soft dice score + + >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long() + >>> #target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> target = target_idx + >>> self = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', pixelwise=True) + >>> loss = self(input, target) + >>> assert tuple(loss.shape) == (B, 1, H, W) + + >>> # Original reduction=None behavior the spacetime dimensions + >>> # are always reduced + >>> self = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', pixelwise=False) + >>> loss = self(input, target) + >>> assert tuple(loss.shape) == (B, 1, 1, 1) + + >>> # Test that pixelwise variants of reduce=none correspond with a reduction mode + >>> r0 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='sum')(input, target) + >>> r1 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', pixelwise=True)(input, target).sum() + >>> r2 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', pixelwise=False)(input, target).sum() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='mean')(input, target) + >>> r1 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', pixelwise=True)(input, target).sum((1, 2, 3)).mean() + >>> r2 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', pixelwise=False)(input, target).mean() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> # Test that pixelwise variants of reduce=none correspond with a reduction mode + >>> r0 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='sum', weighting_mode='GDL')(input, target) + >>> r1 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', weighting_mode='GDL', pixelwise=True)(input, target).sum() + >>> r2 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', weighting_mode='GDL', pixelwise=False)(input, target).sum() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + + >>> r0 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='mean', weighting_mode='GDL')(input, target) + >>> r1 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', weighting_mode='GDL', pixelwise=True)(input, target).sum((1, 2, 3)).mean() + >>> r2 = GeneralizedWassersteinDiceLoss(dist_matrix, reduction='none', weighting_mode='GDL', pixelwise=False)(input, target).mean() + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) + >>> assert torch.allclose(r0, r2, rtol=1e-3, atol=1e-6) """ + # Input shape is Batch, Classes, followed by spacetime dimensions + B, C, *ST_DIMS = input.shape + ST = np.prod(ST_DIMS) + # Aggregate spatial dimensions - flat_input = input.reshape(input.size(0), input.size(1), -1) - flat_target = target.reshape(target.size(0), -1).long() + flat_input = input.reshape(B, C, ST) + flat_target = target.reshape(B, ST).long() # Apply the softmax to the input scores map - probs = F.softmax(flat_input, dim=1) + probs = F.softmax(flat_input, dim=1) # [B, C, ST] # Compute the Wasserstein distance map - wass_dist_map = self.wasserstein_distance_map(probs, flat_target) + wass_dist_map = self.wasserstein_distance_map(probs, flat_target) # [B, ST] + + # Compute the values of alpha to use based on :attr:`self.alpha_mode` + alpha = self._compute_alpha_generalized_true_positives(flat_target) # [B, C] - # Compute the values of alpha to use - alpha = self._compute_alpha_generalized_true_positives(flat_target) + true_pos_split = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map) # [B, ST] + # Aggregate true pos over spatial dims + true_pos = true_pos_split.sum(dim=1) # [B] # Compute the numerator and denominator of the generalized Wasserstein Dice loss if self.alpha_mode == "GDL": # use GDL-style alpha weights (i.e. normalize by the volume of each class) - # contrary to the original definition we also use alpha in the "generalized all error". - true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map) - denom = self._compute_denominator(alpha, flat_target, wass_dist_map) + denom_split = self._compute_denominator(alpha, flat_target, wass_dist_map) # [B, ST] + denom = denom_split.sum(dim=1) # [B] else: # default: as in the original paper - # (i.e. alpha=1 for all foreground classes and 0 for the background). - # Compute the generalised number of true positives - true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map) - all_error = torch.sum(wass_dist_map, dim=1) - denom = 2 * true_pos + all_error - - # Compute the final loss - wass_dice: torch.Tensor = (2.0 * true_pos + self.smooth_nr) / (denom + self.smooth_dr) - wass_dice_loss: torch.Tensor = 1.0 - wass_dice - - if self.reduction == LossReduction.MEAN.value: - wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average - elif self.reduction == LossReduction.SUM.value: - wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims - elif self.reduction != LossReduction.NONE.value: - raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + all_error = torch.sum(wass_dist_map, dim=1) # [B] + denom = 2 * true_pos + all_error # [B] + + if self.pixelwise and self.reduction == LossReduction.NONE.value: + # Dont reduce over the spatial resolution + smooth_nr_split = self.smooth_nr / ST + numer_split = 2.0 * true_pos_split + smooth_nr_split + wass_dice_split: torch.Tensor = numer_split / (denom[:, None] + self.smooth_dr) + lead_split = 1 / ST + wass_dice_loss_flat = lead_split - wass_dice_split + # reshape back to spatial dims, categories are always reduced. + wass_dice_loss = wass_dice_loss_flat.view(B, 1, *ST_DIMS) + else: + # Compute the final loss + numer = 2.0 * true_pos + self.smooth_nr + wass_dice = numer / (denom + self.smooth_dr) + wass_dice_loss_flat = 1.0 - wass_dice + + if self.reduction == LossReduction.MEAN.value: + wass_dice_loss: torch.Tensor = torch.mean(wass_dice_loss_flat) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + wass_dice_loss: torch.Tensor = torch.sum(wass_dice_loss_flat) # sum over the batch and channel dims + elif self.reduction == LossReduction.NONE.value: + # If we are not computing voxelwise loss components at least + # make sure a none reduction maintains a broadcastable shape + broadcast_shape = [B, 1] + ([1] * len(ST_DIMS)) + wass_dice_loss: torch.Tensor = wass_dice_loss_flat.view(broadcast_shape) + else: + raise ValueError( + f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].' + ) return wass_dice_loss @@ -522,44 +763,47 @@ def wasserstein_distance_map(self, flat_proba: torch.Tensor, flat_target: torch. return wasserstein_map def _compute_generalized_true_positive( - self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor + self, alpha: torch.Tensor, flat_target: torch.Tensor, wass_dist_map: torch.Tensor ) -> torch.Tensor: """ Args: alpha: generalised number of true positives of target class. - flat_target: the target tensor. - wasserstein_distance_map: the map obtained from the above function. + shape is (B, C) + + flat_target: the target tensor of class indexes. + shape is (B, ST) + + wass_dist_map: the map obtained from the above function. + shape is (B, ST) """ # Extend alpha to a map and select value at each voxel according to flat_target - alpha_extended = torch.unsqueeze(alpha, dim=2) - alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) + alpha_lut = torch.unsqueeze(alpha, dim=2) + alpha_lut = alpha_lut.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) # [B, C, ST] flat_target_extended = torch.unsqueeze(flat_target, dim=1) - alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - - return torch.sum( - alpha_extended * (1.0 - wasserstein_distance_map), - dim=[1, 2], - ) + alpha_extended = torch.gather(alpha_lut, index=flat_target_extended, dim=1) # [B, 1, ST] + wass_sim = (1.0 - wass_dist_map)[None, :] # [1, B, ST] + prod = alpha_extended * wass_sim # [B, B, ST] + true_pos_split = torch.sum(prod, dim=[1]) # [B, ST] + return true_pos_split def _compute_denominator( - self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor + self, alpha: torch.Tensor, flat_target: torch.Tensor, wass_dist_map: torch.Tensor ) -> torch.Tensor: """ Args: alpha: generalised number of true positives of target class. flat_target: the target tensor. - wasserstein_distance_map: the map obtained from the above function. + wass_dist_map: the map obtained from the above function. """ # Extend alpha to a map and select value at each voxel according to flat_target - alpha_extended = torch.unsqueeze(alpha, dim=2) - alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) + alpha_lut = torch.unsqueeze(alpha, dim=2) + alpha_lut = alpha_lut.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) flat_target_extended = torch.unsqueeze(flat_target, dim=1) - alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - - return torch.sum( - alpha_extended * (2.0 - wasserstein_distance_map), - dim=[1, 2], - ) + alpha_extended = torch.gather(alpha_lut, index=flat_target_extended, dim=1) # [B, 1, ST] + wass_2sim = (2.0 - wass_dist_map)[None, :] # [1, B, ST] + prod = alpha_extended * wass_2sim # [B, B, ST] + denom_split = torch.sum(prod, dim=[1]) # [B, ST] + return denom_split def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor: """ @@ -568,13 +812,16 @@ def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) - """ alpha: torch.Tensor = torch.ones((flat_target.size(0), self.num_classes)).float().to(flat_target.device) if self.alpha_mode == "GDL": # GDL style - # Define alpha like in the generalized dice loss - # i.e. the inverse of the volume of each class. + # use GDL-style alpha weights (i.e. normalize by the volume of each class) + # contrary to the original definition we also use alpha in the "generalized all error". one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).float() volumes = torch.sum(one_hot_f, dim=2) alpha = 1.0 / (volumes + 1.0) - else: # default, i.e. like in the original paper - # alpha weights are 0 for the background and 1 the other classes + else: + # default, i.e. like in the original paper + # (i.e. alpha=1 for all foreground classes and 0 for the background). + # Compute the generalised number of true positives + # TODO: parametarize background index alpha[:, 0] = 0.0 return alpha @@ -605,6 +852,7 @@ def __init__( ce_weight: Optional[torch.Tensor] = None, lambda_dice: float = 1.0, lambda_ce: float = 1.0, + balance_broadcast: bool = False, ) -> None: """ Args: @@ -668,6 +916,14 @@ def __init__( self.lambda_dice = lambda_dice self.lambda_ce = lambda_ce + self.reduction = reduction + + # Better Name? + self.balance_broadcast = balance_broadcast + if balance_broadcast: + if self.reduction != LossReduction.NONE.value: + raise ValueError('Can only compute balance_broadcast loss when reduction is "none"') + def ce(self, input: torch.Tensor, target: torch.Tensor): """ Compute CrossEntropy loss for the input and target. @@ -694,12 +950,70 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ValueError: When number of dimensions for input and target are different. ValueError: When number of channels for target is neither 1 nor the same as input. + Example: + >>> import torch + >>> from monai.losses.dice import DiceLoss + >>> C = 5 + >>> input = torch.rand(7, 5, 3, 2) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(7, 3, 2)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = DiceCELoss(reduction='none') + >>> loss = self(input, target) + >>> assert loss.shape == input.shape + + >>> # Test that pixelwise variants of reduce=none correspond with a reduction mode + >>> r0 = DiceCELoss(reduction='sum', batch=False)(input, target) + >>> r1 = DiceCELoss(reduction='none', batch=False, balance_broadcast=True)(input, target).sum() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceCELoss(reduction='sum', batch=True)(input, target) + >>> r1 = DiceCELoss(reduction='none', batch=True, balance_broadcast=True)(input, target).sum() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceCELoss(reduction='mean', batch=False)(input, target) + >>> r1 = DiceCELoss(reduction='none', batch=False, balance_broadcast=False)(input, target).mean() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceCELoss(reduction='mean', batch=True)(input, target) + >>> r1 = DiceCELoss(reduction='none', batch=True, balance_broadcast=False)(input, target).mean() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) """ if len(input.shape) != len(target.shape): raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) ce_loss = self.ce(input, target) + + if self.reduction == LossReduction.NONE.value: + # Expand the class dimension for reduction=none compatability + ce_loss = ce_loss[:, None, ...] + + # If we want to apply "mean" reduction to a "none" reduced + # item after the fact, balance_broadcast must be False, + # and for "sum", balance_broadcast must be True. + if self.balance_broadcast: + # Broadcasting will introduce duplicates of items, so we have to + # componestate for that. This does cause "mean" reduction + # to compoenstate + nitems_final = np.prod(torch.broadcast_shapes(dice_loss.shape, ce_loss.shape)) + + nitems_dice = np.prod(dice_loss.shape) + nitems_ce = np.prod(ce_loss.shape) + + dice_bcast_factor = nitems_final // nitems_dice + ce_bcase_factor = nitems_final // nitems_ce + + dice_loss = dice_loss * (1.0 / dice_bcast_factor) + ce_loss = ce_loss * (1.0 / ce_bcase_factor) + total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss return total_loss @@ -730,6 +1044,7 @@ def __init__( focal_weight: Optional[Union[Sequence[float], float, int, torch.Tensor]] = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, + balance_broadcast: bool = False, ) -> None: """ Args: @@ -796,6 +1111,13 @@ def __init__( raise ValueError("lambda_focal should be no less than 0.0.") self.lambda_dice = lambda_dice self.lambda_focal = lambda_focal + self.reduction = reduction + + # Better Name? + self.balance_broadcast = balance_broadcast + if balance_broadcast: + if self.reduction != LossReduction.NONE.value: + raise ValueError('Can only compute balance_broadcast loss when reduction is "none"') def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -808,12 +1130,67 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ValueError: When number of dimensions for input and target are different. ValueError: When number of channels for target is neither 1 nor the same as input. + Example: + >>> import torch + >>> from monai.losses.dice import DiceLoss + >>> C = 5 + >>> input = torch.rand(7, 5, 3, 2) + >>> target_idx = torch.randint(low=0, high=C - 1, size=(7, 3, 2)).long() + >>> target = one_hot(target_idx[:, None, ...], num_classes=C) + >>> self = DiceFocalLoss(reduction='none') + >>> loss = self(input, target) + >>> assert loss.shape == input.shape + + >>> # Test that pixelwise variants of reduce=none correspond with a reduction mode + >>> r0 = DiceFocalLoss(reduction='sum', batch=False)(input, target) + >>> r1 = DiceFocalLoss(reduction='none', batch=False, balance_broadcast=True)(input, target).sum() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceFocalLoss(reduction='sum', batch=True)(input, target) + >>> r1 = DiceFocalLoss(reduction='none', batch=True, balance_broadcast=True)(input, target).sum() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceFocalLoss(reduction='mean', batch=False)(input, target) + >>> r1 = DiceFocalLoss(reduction='none', batch=False)(input, target).mean() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) + + >>> r0 = DiceFocalLoss(reduction='mean', batch=True)(input, target) + >>> r1 = DiceFocalLoss(reduction='none', batch=True)(input, target).mean() + >>> print('r0 = {!r}'.format(r0)) + >>> print('r1 = {!r}'.format(r1)) + >>> assert torch.allclose(r0, r1, rtol=1e-3, atol=1e-6) """ if len(input.shape) != len(target.shape): raise ValueError("the number of dimensions for input and target should be the same.") dice_loss = self.dice(input, target) focal_loss = self.focal(input, target) + + if self.reduction == LossReduction.NONE.value: + # If we want to apply "mean" reduction to a "none" reduced + # item after the fact, balance_broadcast must be False, + # and for "sum", balance_broadcast must be True. + if self.balance_broadcast: + # Broadcasting will introduce duplicates of items, so we have to + # componestate for that. This does cause "mean" reduction + # to compoenstate + nitems_final = np.prod(torch.broadcast_shapes(dice_loss.shape, focal_loss.shape)) + + nitems_dice = np.prod(dice_loss.shape) + nitems_focal = np.prod(focal_loss.shape) + + dice_bcast_factor = nitems_final // nitems_dice + focal_bcase_factor = nitems_final // nitems_focal + + dice_loss = dice_loss * (1.0 / dice_bcast_factor) + focal_loss = focal_loss * (1.0 / focal_bcase_factor) + total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss return total_loss diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 157ce9fd01..120e373e86 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -89,6 +89,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: number of classes. ValueError: When ``self.weight`` is/contains a value that is less than 0. + Example: + >>> import torch + >>> from monai.losses.focal_loss import FocalLoss + >>> input = torch.rand(7, 5, 3, 2) + >>> target = torch.rand(7, 5, 3, 2) + >>> self = FocalLoss(reduction='none') + >>> loss = self(input, target) + >>> assert loss.shape == input.shape + """ n_pred_ch = input.shape[1] @@ -147,12 +156,32 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # Compute the loss mini-batch. # (1-p_t)^gamma * log(p_t) with reduced chance of overflow p = F.logsigmoid(-i * (t * 2.0 - 1.0)) - loss = torch.mean((p * self.gamma).exp() * ce, dim=-1) + loss = (p * self.gamma).exp() * ce if self.reduction == LossReduction.SUM.value: return loss.sum() if self.reduction == LossReduction.NONE.value: - return loss + voxel_dims = list(input.shape[2:]) + # Hack for JIT, which requires static parsing of reshape + if len(voxel_dims) == 1: + (d1,) = voxel_dims + return loss.reshape(b, n, d1) + elif len(voxel_dims) == 2: + d1, d2 = voxel_dims + return loss.reshape(b, n, d1, d2) + elif len(voxel_dims) == 3: + d1, d2, d3 = voxel_dims + return loss.reshape(b, n, d1, d2, d3) + elif len(voxel_dims) == 4: + d1, d2, d3, d4 = voxel_dims + return loss.reshape(b, n, d1, d2, d3, d4) + elif len(voxel_dims) == 5: + d1, d2, d3, d4, d5 = voxel_dims + return loss.reshape(b, n, d1, d2, d3, d4, d5) + else: + # JIT prevents use from coding the general case + # return loss.reshape(b, n, *voxel_dims) + raise NotImplementedError if self.reduction == LossReduction.MEAN.value: return loss.mean() raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1314fe3841..bec3a6fb84 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -44,6 +44,35 @@ def test_consistency_with_cross_entropy_2d(self): max_error = abs(a - b) self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_2d_no_reduction(self): + """For gamma=0 the focal loss reduces to the cross entropy loss""" + import numpy as np + + focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="none", weight=1.0) + ce = nn.BCEWithLogitsLoss(reduction="none") + max_error = 0 + class_num = 10 + batch_size = 128 + for _ in range(100): + # Create a random tensor of shape (batch_size, class_num, 8, 4) + x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True) + # Create a random batch of classes + l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float() + if torch.cuda.is_available(): + x = x.cuda() + l = l.cuda() + output0 = focal_loss(x, l) + output1 = ce(x, l) + a = output0.cpu().detach().numpy() + b = output1.cpu().detach().numpy() + error = np.abs(a - b) + max_error = np.maximum(error, max_error) + # if np.all(np.abs(a - b) > max_error): + # max_error = np.abs(a - b) + + assert np.allclose(max_error, 0) + # self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_2d_onehot_label(self): """For gamma=0 the focal loss reduces to the cross entropy loss""" focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean")