Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 102 additions & 17 deletions monai/losses/image_dissimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks.layers import gaussian_1d, separable_filtering
from monai.utils import LossReduction, deprecated_arg
from monai.utils.module import look_up_option


def make_rectangular_kernel(kernel_size: int) -> torch.Tensor:
Expand Down Expand Up @@ -92,18 +93,15 @@ def __init__(
if ndim is not None:
spatial_dims = ndim
self.ndim = spatial_dims
if self.ndim not in [1, 2, 3]:
if self.ndim not in {1, 2, 3}:
raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported")

self.kernel_size = kernel_size
if self.kernel_size % 2 == 0:
raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

if kernel_type not in kernel_dict.keys():
raise ValueError(
f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].'
)
self.kernel = kernel_dict[kernel_type](self.kernel_size)
_kernel = look_up_option(kernel_type, kernel_dict)
self.kernel = _kernel(self.kernel_size)
self.kernel_vol = self.get_kernel_vol()

self.smooth_nr = float(smooth_nr)
Expand Down Expand Up @@ -177,6 +175,7 @@ class GlobalMutualInformationLoss(_Loss):

def __init__(
self,
kernel_type: str = "gaussian",
num_bins: int = 23,
sigma_ratio: float = 0.5,
reduction: Union[LossReduction, str] = LossReduction.MEAN,
Expand All @@ -185,6 +184,19 @@ def __init__(
) -> None:
"""
Args:
kernel_type: {``"gaussian"``, ``"b-spline"``}
``"gaussian"``: adapted from DeepReg
Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1.
``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK
References:
[1] "Nonrigid multimodality image registration"
D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620.
[2] "PET-CT Image Registration in the Chest Using Free-form Deformations"
D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
IEEE Transactions on Medical Imaging. Vol.22, No.1,
January 2003. pp.120-128.

num_bins: number of bins for intensity
sigma_ratio: a hyper param for gaussian function
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Expand All @@ -201,20 +213,94 @@ def __init__(
raise ValueError("num_bins must > 0, got {num_bins}")
bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,)
sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
self.preterm = 1 / (2 * sigma ** 2)
self.bin_centers = bin_centers[None, None, ...]
self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"])
self.num_bins = num_bins
self.kernel_type = kernel_type
if self.kernel_type == "gaussian":
self.preterm = 1 / (2 * sigma ** 2)
self.bin_centers = bin_centers[None, None, ...]
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)

def parzen_windowing(
    self, pred: torch.Tensor, target: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Estimate the intensity distributions of ``pred`` and ``target`` by Parzen windowing,
    dispatching on ``self.kernel_type``.

    Args:
        pred: the predicted image, forwarded unchanged to the kernel-specific method.
        target: the target image, forwarded unchanged to the kernel-specific method.

    Returns:
        ``(pred_weight, pred_probability, target_weight, target_probability)``, i.e. the
        two values returned by the kernel-specific method for ``pred`` followed by the
        two values returned for ``target``.

    Raises:
        ValueError: when ``self.kernel_type`` is neither ``"gaussian"`` nor ``"b-spline"``.
    """
    if self.kernel_type == "gaussian":
        pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
        target_weight, target_probability = self.parzen_windowing_gaussian(target)
    elif self.kernel_type == "b-spline":
        # a third order BSpline kernel is used for the pred image intensity PDF.
        pred_weight, pred_probability = self.parzen_windowing_b_spline(pred, order=3)
        # a zero order (box car) BSpline kernel is used for the target image intensity PDF.
        target_weight, target_probability = self.parzen_windowing_b_spline(target, order=0)
    else:
        # Name the offending value so a mis-set attribute is easy to diagnose.
        raise ValueError(
            f'Unsupported kernel_type: {self.kernel_type}, available options are ["gaussian", "b-spline"].'
        )
    return pred_weight, pred_probability, target_weight, target_probability

def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Parzen windowing with b-spline kernel (adapted from ITK)

    Args:
        img: the shape should be B[NDHW].
        order: order of the b-spline kernel; only ``0`` (box car) and ``3`` (cubic) are supported.

    Returns:
        A tuple ``(weight, probability)``: ``weight`` has shape (batch, num_sample, num_bins)
        and is normalised over the bin axis; ``probability`` has shape (batch, 1, num_bins)
        and is the mean of ``weight`` over the sample axis.

    Raises:
        ValueError: when ``order`` is not 0 or 3, or when ``self.num_bins`` leaves no
            non-padded bin (``num_bins <= 4``).
    """
    padding = 2
    if order not in (0, 3):
        raise ValueError(f"Do not support b-spline {order}-order parzen windowing")
    if self.num_bins <= 2 * padding:
        # (num_bins - 2 * padding) is the divisor of the bin size below; zero or negative
        # values would silently produce an infinite or negative bin size.
        raise ValueError(
            f"num_bins must be greater than {2 * padding} for b-spline parzen windowing, got {self.num_bins}"
        )

    # Compute binsize for the histograms.
    #
    # The binsize for the image intensities needs to be adjusted so that
    # we can avoid dealing with boundary conditions using the cubic
    # spline as the Parzen window. We do this by increasing the size
    # of the bins so that the joint histogram becomes "padded" at the
    # borders. Because we are changing the binsize,
    # we also need to shift the minimum by the padded amount in order to
    # avoid minimum values filling in our padded region.
    #
    # Note that there can still be non-zero bin values in the padded region,
    # it's just that these bins will never be a central bin for the Parzen
    # window.
    _max, _min = torch.max(img), torch.min(img)
    bin_size = (_max - _min) / (self.num_bins - 2 * padding)
    # A constant image has zero intensity range, which would make bin_size zero and turn
    # every weight into NaN; fall back to a unit bin size so all voxels share one window.
    bin_size = torch.where(bin_size > 0, bin_size, torch.ones_like(bin_size))
    norm_min = torch.div(_min, bin_size, rounding_mode="floor") - padding

    # assign bin/window index to each voxel
    window_term = torch.div(img, bin_size) - norm_min  # B[NDHW]
    # make sure the extreme values are in valid (non-padded) bins
    window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1)  # B[NDHW]
    window_term = window_term.reshape(window_term.shape[0], -1, 1)  # (batch, num_sample, 1)
    bins = torch.arange(self.num_bins, device=window_term.device).reshape(1, 1, -1)  # (1, 1, num_bins)
    sample_bin_matrix = torch.abs(bins - window_term)  # (batch, num_sample, num_bins)

    # b-spline kernel
    # order 0 (box car): 1 when abs < 0.5, 0.5 when abs == 0.5
    # order 3 (cubic):
    #   (4 - 6 * abs ** 2 + 3 * abs ** 3) / 6 when 0 <= abs < 1
    #   (2 - abs) ** 3 / 6 when 1 <= abs < 2
    weight = torch.zeros_like(sample_bin_matrix, dtype=torch.float)  # (batch, num_sample, num_bins)
    if order == 0:
        weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5
    else:  # order == 3, validated above
        weight = (
            weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6
        )
        weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6

    weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bins)
    probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bins)
    return weight, probability

def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Parzen windowing with gaussian kernel (adapted from DeepReg implementation)
Note: the input is expected to range between 0 and 1
Args:
img: the shape should be B[NDHW].
"""
pred = torch.clamp(pred, 0, 1)
pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1)
img = torch.clamp(img, 0, 1)
img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
weight = torch.exp(
-self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2
-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2
) # (batch, num_sample, num_bin)
weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin)
probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin)
Expand All @@ -230,11 +316,10 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
if target.shape != pred.shape:
raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
wa, pa = self.parzen_windowing(pred) # (batch, num_sample, num_bin), (batch, 1, num_bin)
wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin)
pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins)
wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin)

papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins)
pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
mi = torch.sum(
pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)
) # (batch)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_global_mutual_information_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from parameterized import parameterized

from monai.losses.image_dissimilarity import GlobalMutualInformationLoss
from tests.utils import SkipIfBeforePyTorchVersion

device = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -45,6 +46,31 @@
},
-1.083999,
],
[
{"kernel_type": "b-spline"},
{
"pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None]
.expand(1, 3, 3, 3, 3)
.div(3),
"target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None]
.expand(1, 3, 3, 3, 3)
.div(3),
},
-1.0986018,
],
[
{"kernel_type": "b-spline"},
{
"pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None]
.expand(1, 3, 3, 3, 3)
.div(3),
"target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None]
.expand(1, 3, 3, 3, 3)
.div(3)
** 2,
},
-1.09861,
],
[
{},
{
Expand Down Expand Up @@ -85,6 +111,7 @@

class TestGlobalMutualInformationLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
@SkipIfBeforePyTorchVersion((1, 9))
def test_shape(self, input_param, input_data, expected_val):
result = GlobalMutualInformationLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_reg_loss_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from parameterized import parameterized

from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from tests.utils import SkipIfBeforePyTorchVersion

TEST_CASES = [
[BendingEnergyLoss, {}, ["pred"]],
Expand All @@ -36,6 +37,7 @@
["pred", "target"],
],
[GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]],
[GlobalMutualInformationLoss, {"kernel_type": "b-spline", "num_bins": 10}, ["pred", "target"]],
]


Expand All @@ -51,6 +53,7 @@ def tearDown(self):
torch.backends.cudnn.benchmark = True

@parameterized.expand(TEST_CASES)
@SkipIfBeforePyTorchVersion((1, 9))
def test_convergence(self, loss_type, loss_args, forward_args):
"""
The goal of this test is to assess if the gradient of the loss function
Expand Down