From b24454eceec062b0c20153ac19ac58e6415bf9c0 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Wed, 25 Aug 2021 17:38:33 +0100 Subject: [PATCH 1/7] 2755 Add b-spline kernel option Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 104 +++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 9 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 9eb4540a6b..ad6c3102f1 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -177,6 +177,7 @@ class GlobalMutualInformationLoss(_Loss): def __init__( self, + kernel_type: str = "gaussian", num_bins: int = 23, sigma_ratio: float = 0.5, reduction: Union[LossReduction, str] = LossReduction.MEAN, @@ -185,6 +186,19 @@ def __init__( ) -> None: """ Args: + kernel_type: {``"gaussian"``, ``"spline"``} + - ``"gaussian"``: adapted from DeepReg + Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1 + - ``"spline"``: based on the method of Mattes et al [1,2] and adapted from ITK + References: + [1] "Nonrigid multimodality image registration" + D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank + Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620. + [2] "PET-CT Image Registration in the Chest Using Free-form Deformations" + D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank + IEEE Transactions in Medical Imaging. Vol.22, No.1, + January 2003. pp.120-128. + num_bins: number of bins for intensity sigma_ratio: a hyper param for gaussian function reduction: {``"none"``, ``"mean"``, ``"sum"``} @@ -201,20 +215,93 @@ def __init__( raise ValueError("num_bins must > 0, got {num_bins}") bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio - self.preterm = 1 / (2 * sigma ** 2) - self.bin_centers = bin_centers[None, None, ...] + if kernel_type not in ["gaussian", "b-spline"]: + raise ValueError( + f'Unsupported kernel_type: {kernel_type}, available options are ["gaussian", "b-spline].' + ) + self.kernel_type = kernel_type + if self.kernel_type == "gaussian": + self.preterm = 1 / (2 * sigma ** 2) + self.bin_centers = bin_centers[None, None, ...] self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def parzen_windowing(self, pred: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + ]: + if self.kernel_type == "gaussian": + pred_weight, pred_probability = self.parzen_windowing_gaussian(pred) + target_weight, target_probability = self.parzen_windowing_gaussian(target) + elif self.kernel_type == "b-spline": + # a third order BSpline kernel is used for the pred image intensity PDF. + pred_weight, pred_probability = self.parzen_windowing_b_spline(pred, order=3) + # a zero order (box car) BSpline kernel is used for the target image intensity PDF. + target_weight, target_probability = self.parzen_windowing_b_spline(target, order=0) + else: + raise ValueError + return pred_weight, pred_probability, target_weight, target_probability + + def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Tensor, torch.Tensor]: """ + Parzen windowing with b-spline kernel (adapted from ITK) Args: - pred: the shape should be B[NDHW]. + img: the shape should be B[NDHW]. + order: int. + """ + + # Compute binsize for the histograms. + # + # The binsize for the image intensities needs to be adjusted so that + # we can avoid dealing with boundary conditions using the cubic + # spline as the Parzen window. We do this by increasing the size + # of the bins so that the joint histogram becomes "padded" at the + # borders. Because we are changing the binsize, + # we also need to shift the minimum by the padded amount in order to + # avoid minimum values filling in our padded region. + # + # Note that there can still be non-zero bin values in the padded region, + # it's just that these bins will never be a central bin for the Parzen + # window. + max, min = torch.max(img), torch.min(img) + padding = 2 + bin_size = (max - min) // (self.num_bins - 2 * padding) + norm_min = min // bin_size - padding + + # assign bin/window index to each voxel + window_term = img // bin_size - norm_min # B[NDHW] + # make sure the extreme values are in valid (non-padded) bins + window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] + window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) + bins = torch.arange(self.num_bins).reshape(1, 1, -1) # (1, 1, num_bins) + sample_bin_matrix = torch.abs(bins - window_term) # (batch, num_sample, num_bins) + + # b-spleen kernel + # (4 - 6 * abs ** 2 + 3 * abs ** 3) / 6 when 0 <= abs < 1 + # (2 - abs) ** 3 / 6 when 1 <= abs < 2 + weight = torch.zeros_like(sample_bin_matrix, dtype=torch.float) # (batch, num_sample, num_bins) + if order == 1: + weight = weight + sample_bin_matrix == 0 + elif order == 3: + weight = weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix == 0) + weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix == 1) + else: + raise ValueError(f'Do not support b-spline {order}-order parzen windowing') + + weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bins) + probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bins) + return weight, probability + + def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parzen windowing with gaussian kernel (adapted from DeepReg implementation) + Note: the input is expected to range between 0 and 1 + Args: + img: the shape should be B[NDHW]. """ - pred = torch.clamp(pred, 0, 1) - pred = pred.reshape(pred.shape[0], -1, 1) # (batch, num_sample, 1) + img = torch.clamp(img, 0, 1) + img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1) weight = torch.exp( - -self.preterm.to(pred) * (pred - self.bin_centers.to(pred)) ** 2 + -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2 ) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) @@ -230,8 +317,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") - wa, pa = self.parzen_windowing(pred) # (batch, num_sample, num_bin), (batch, 1, num_bin) - wb, pb = self.parzen_windowing(target) # (batch, num_sample, num_bin), (batch, 1, num_bin) + wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin) pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) From edfede3cb6a2c854be9456a57dc6e583e036a107 Mon Sep 17 00:00:00 2001 From: kate-sann5100 Date: Thu, 9 Sep 2021 23:27:48 +0100 Subject: [PATCH 2/7] 2755 debug b-spline kernel Signed-off-by: kate-sann5100 --- monai/losses/image_dissimilarity.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index ad6c3102f1..af36f1323a 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -219,6 +219,7 @@ def __init__( raise ValueError( f'Unsupported kernel_type: {kernel_type}, available options are ["gaussian", "b-spline].' ) + self.num_bins = num_bins self.kernel_type = kernel_type if self.kernel_type == "gaussian": self.preterm = 1 / (2 * sigma ** 2) @@ -264,11 +265,11 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Ten # window. max, min = torch.max(img), torch.min(img) padding = 2 - bin_size = (max - min) // (self.num_bins - 2 * padding) - norm_min = min // bin_size - padding + bin_size = (max - min) / (self.num_bins - 2 * padding) + norm_min = torch.div(min, bin_size, rounding_mode='floor') - padding # assign bin/window index to each voxel - window_term = img // bin_size - norm_min # B[NDHW] + window_term = torch.div(img, bin_size, rounding_mode='floor') - norm_min # B[NDHW] # make sure the extreme values are in valid (non-padded) bins window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) @@ -279,11 +280,11 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Ten # (4 - 6 * abs ** 2 + 3 * abs ** 3) / 6 when 0 <= abs < 1 # (2 - abs) ** 3 / 6 when 1 <= abs < 2 weight = torch.zeros_like(sample_bin_matrix, dtype=torch.float) # (batch, num_sample, num_bins) - if order == 1: - weight = weight + sample_bin_matrix == 0 + if order == 0: + weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: - weight = weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix == 0) - weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix == 1) + weight = weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6 + weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: raise ValueError(f'Do not support b-spline {order}-order parzen windowing') @@ -318,9 +319,9 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != pred.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})") wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin) - pab = torch.bmm(wa.permute(0, 2, 1), wb).div(wa.shape[1]) # (batch, num_bins, num_bins) - papb = torch.bmm(pa.permute(0, 2, 1), pb) # (batch, num_bins, num_bins) + pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins) + papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins) mi = torch.sum( pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2) ) # (batch) From 6bebf0f0002e684efe9a19b81f6a918803a129df Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 10 Sep 2021 07:16:47 +0000 Subject: [PATCH 3/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/losses/image_dissimilarity.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index af36f1323a..b8e1537b1a 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -216,9 +216,7 @@ def __init__( bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio if kernel_type not in ["gaussian", "b-spline"]: - raise ValueError( - f'Unsupported kernel_type: {kernel_type}, available options are ["gaussian", "b-spline].' - ) + raise ValueError(f'Unsupported kernel_type: {kernel_type}, available options are ["gaussian", "b-spline].') self.num_bins = num_bins self.kernel_type = kernel_type if self.kernel_type == "gaussian": @@ -227,9 +225,9 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) - def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor - ]: + def parzen_windowing( + self, pred: torch.Tensor, target: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if self.kernel_type == "gaussian": pred_weight, pred_probability = self.parzen_windowing_gaussian(pred) target_weight, target_probability = self.parzen_windowing_gaussian(target) @@ -266,14 +264,14 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Ten max, min = torch.max(img), torch.min(img) padding = 2 bin_size = (max - min) / (self.num_bins - 2 * padding) - norm_min = torch.div(min, bin_size, rounding_mode='floor') - padding + norm_min = torch.div(min, bin_size, rounding_mode="floor") - padding # assign bin/window index to each voxel - window_term = torch.div(img, bin_size, rounding_mode='floor') - norm_min # B[NDHW] + window_term = torch.div(img, bin_size, rounding_mode="floor") - norm_min # B[NDHW] # make sure the extreme values are in valid (non-padded) bins window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) - bins = torch.arange(self.num_bins).reshape(1, 1, -1) # (1, 1, num_bins) + bins = torch.arange(self.num_bins).reshape(1, 1, -1) # (1, 1, num_bins) sample_bin_matrix = torch.abs(bins - window_term) # (batch, num_sample, num_bins) # b-spleen kernel @@ -283,10 +281,12 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Ten if order == 0: weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5 elif order == 3: - weight = weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6 + weight = ( + weight + (4 - 6 * sample_bin_matrix ** 2 + 3 * sample_bin_matrix ** 3) * (sample_bin_matrix < 1) / 6 + ) weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6 else: - raise ValueError(f'Do not support b-spline {order}-order parzen windowing') + raise ValueError(f"Do not support b-spline {order}-order parzen windowing") weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bins) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bins) From e5c6db061ab4b8c26f4209621fb0f99e1129df05 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 4 Oct 2021 17:59:20 +0100 Subject: [PATCH 4/7] adds tests Signed-off-by: Wenqi Li --- monai/losses/image_dissimilarity.py | 30 +++++++++----------- tests/test_global_mutual_information_loss.py | 25 ++++++++++++++++ tests/test_reg_loss_integration.py | 2 ++ 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index b8e1537b1a..5233f75b00 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -16,6 +16,7 @@ from monai.networks.layers import gaussian_1d, separable_filtering from monai.utils import LossReduction, deprecated_arg +from monai.utils.module import look_up_option def make_rectangular_kernel(kernel_size: int) -> torch.Tensor: @@ -92,18 +93,15 @@ def __init__( if ndim is not None: spatial_dims = ndim self.ndim = spatial_dims - if self.ndim not in [1, 2, 3]: + if self.ndim not in {1, 2, 3}: raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported") self.kernel_size = kernel_size if self.kernel_size % 2 == 0: raise ValueError(f"kernel_size must be odd, got {self.kernel_size}") - if kernel_type not in kernel_dict.keys(): - raise ValueError( - f'Unsupported kernel_type: {kernel_type}, available options are ["rectangular", "triangular", "gaussian"].' - ) - self.kernel = kernel_dict[kernel_type](self.kernel_size) + _kernel = look_up_option(kernel_type, kernel_dict) + self.kernel = _kernel(self.kernel_size) self.kernel_vol = self.get_kernel_vol() self.smooth_nr = float(smooth_nr) @@ -186,15 +184,15 @@ def __init__( ) -> None: """ Args: - kernel_type: {``"gaussian"``, ``"spline"``} - - ``"gaussian"``: adapted from DeepReg - Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1 - - ``"spline"``: based on the method of Mattes et al [1,2] and adapted from ITK + kernel_type: {``"gaussian"``, ``"b-spline"``} + ``"gaussian"``: adapted from DeepReg + Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1. + ``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK References: - [1] "Nonrigid multimodality image registration" + [1] "Nonrigid multimodality image registration" D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620. - [2] "PET-CT Image Registration in the Chest Using Free-form Deformations" + [2] "PET-CT Image Registration in the Chest Using Free-form Deformations" D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank IEEE Transactions in Medical Imaging. Vol.22, No.1, January 2003. pp.120-128. @@ -215,8 +213,7 @@ def __init__( raise ValueError("num_bins must > 0, got {num_bins}") bin_centers = torch.linspace(0.0, 1.0, num_bins) # (num_bins,) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio - if kernel_type not in ["gaussian", "b-spline"]: - raise ValueError(f'Unsupported kernel_type: {kernel_type}, available options are ["gaussian", "b-spline].') + self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"]) self.num_bins = num_bins self.kernel_type = kernel_type if self.kernel_type == "gaussian": @@ -240,9 +237,10 @@ def parzen_windowing( raise ValueError return pred_weight, pred_probability, target_weight, target_probability - def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Tensor, torch.Tensor]: + def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torch.Tensor, torch.Tensor]: """ Parzen windowing with b-spline kernel (adapted from ITK) + Args: img: the shape should be B[NDHW]. order: int. @@ -267,7 +265,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order) -> Tuple[torch.Ten norm_min = torch.div(min, bin_size, rounding_mode="floor") - padding # assign bin/window index to each voxel - window_term = torch.div(img, bin_size, rounding_mode="floor") - norm_min # B[NDHW] + window_term = torch.div(img, bin_size) - norm_min # B[NDHW] # make sure the extreme values are in valid (non-padded) bins window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 3373b59621..451ecb3b9b 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -45,6 +45,31 @@ }, -1.083999, ], + [ + {"kernel_type": "b-spline"}, + { + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3), + }, + -1.0986018, + ], + [ + {"kernel_type": "b-spline"}, + { + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3) + ** 2, + }, + -1.09861, + ], [ {}, { diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 36e3a460a5..bcd473516e 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -36,6 +36,7 @@ ["pred", "target"], ], [GlobalMutualInformationLoss, {"num_bins": 10}, ["pred", "target"]], + [GlobalMutualInformationLoss, {"kernel_type": "b-spline", "num_bins": 10}, ["pred", "target"]], ] @@ -98,6 +99,7 @@ def forward(self, x): loss_input = {"pred": output, "target": target} loss_val = loss(**{k: loss_input[k] for k in forward_args}) + print(loss_val) if it == 0: init_loss = loss_val From 19236cea5f41fe6aafa90f9e75332081158320f7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 4 Oct 2021 18:01:50 +0100 Subject: [PATCH 5/7] remove debug print Signed-off-by: Wenqi Li --- tests/test_reg_loss_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index bcd473516e..b9b730f731 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -99,7 +99,6 @@ def forward(self, x): loss_input = {"pred": output, "target": target} loss_val = loss(**{k: loss_input[k] for k in forward_args}) - print(loss_val) if it == 0: init_loss = loss_val From bebe00c80bd2ba0ca534465e66709bc9eb3c4551 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 4 Oct 2021 20:43:43 +0100 Subject: [PATCH 6/7] fixes tests Signed-off-by: Wenqi Li --- monai/losses/image_dissimilarity.py | 6 +++--- tests/test_global_mutual_information_loss.py | 2 ++ tests/test_reg_loss_integration.py | 2 ++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 5233f75b00..1092d666a1 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -259,10 +259,10 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torc # Note that there can still be non-zero bin values in the padded region, # it's just that these bins will never be a central bin for the Parzen # window. - max, min = torch.max(img), torch.min(img) + _max, _min = torch.max(img), torch.min(img) padding = 2 - bin_size = (max - min) / (self.num_bins - 2 * padding) - norm_min = torch.div(min, bin_size, rounding_mode="floor") - padding + bin_size = (_max - _min) / (self.num_bins - 2 * padding) + norm_min = torch.div(_min, bin_size, rounding_mode="floor") - padding # assign bin/window index to each voxel window_term = torch.div(img, bin_size) - norm_min # B[NDHW] diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 451ecb3b9b..a688ea8394 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -16,6 +16,7 @@ from parameterized import parameterized from monai.losses.image_dissimilarity import GlobalMutualInformationLoss +from tests.utils import SkipIfBeforePyTorchVersion device = "cuda" if torch.cuda.is_available() else "cpu" @@ -110,6 +111,7 @@ class TestGlobalMutualInformationLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 9)) def test_shape(self, input_param, input_data, expected_val): result = GlobalMutualInformationLoss(**input_param).forward(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4) diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index b9b730f731..1578aa4888 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -17,6 +17,7 @@ from parameterized import parameterized from monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss +from tests.utils import SkipIfBeforePyTorchVersion TEST_CASES = [ [BendingEnergyLoss, {}, ["pred"]], @@ -52,6 +53,7 @@ def tearDown(self): torch.backends.cudnn.benchmark = True @parameterized.expand(TEST_CASES) + @SkipIfBeforePyTorchVersion((1, 9)) def test_convergence(self, loss_type, loss_args, forward_args): """ The goal of this test is to assess if the gradient of the loss function From beac2ced445c4ed245b8e1880ed0754b7b432ec8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 4 Oct 2021 22:19:52 +0100 Subject: [PATCH 7/7] fixes gpu tests Signed-off-by: Wenqi Li --- monai/losses/image_dissimilarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index 1092d666a1..78f92303fc 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -269,7 +269,7 @@ def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> Tuple[torc # make sure the extreme values are in valid (non-padded) bins window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1) # B[NDHW] window_term = window_term.reshape(window_term.shape[0], -1, 1) # (batch, num_sample, 1) - bins = torch.arange(self.num_bins).reshape(1, 1, -1) # (1, 1, num_bins) + bins = torch.arange(self.num_bins, device=window_term.device).reshape(1, 1, -1) # (1, 1, num_bins) sample_bin_matrix = torch.abs(bins - window_term) # (batch, num_sample, num_bins) # b-spleen kernel