diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py index b229a0c08f..431167447b 100644 --- a/monai/losses/image_dissimilarity.py +++ b/monai/losses/image_dissimilarity.py @@ -129,11 +129,11 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: t2, p2, tp = target ** 2, pred ** 2, target * pred kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred) # sum over kernel - t_sum = separable_filtering(target, kernels=[kernel] * self.ndim) - p_sum = separable_filtering(pred, kernels=[kernel] * self.ndim) - t2_sum = separable_filtering(t2, kernels=[kernel] * self.ndim) - p2_sum = separable_filtering(p2, kernels=[kernel] * self.ndim) - tp_sum = separable_filtering(tp, kernels=[kernel] * self.ndim) + t_sum = separable_filtering(target, kernels=[kernel.to(pred)] * self.ndim) + p_sum = separable_filtering(pred, kernels=[kernel.to(pred)] * self.ndim) + t2_sum = separable_filtering(t2, kernels=[kernel.to(pred)] * self.ndim) + p2_sum = separable_filtering(p2, kernels=[kernel.to(pred)] * self.ndim) + tp_sum = separable_filtering(tp, kernels=[kernel.to(pred)] * self.ndim) # average over kernel t_avg = t_sum / kernel_vol diff --git a/monai/losses/multi_scale.py b/monai/losses/multi_scale.py index 5a17bc2d07..af23e03440 100644 --- a/monai/losses/multi_scale.py +++ b/monai/losses/multi_scale.py @@ -82,8 +82,8 @@ def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: else: loss_list.append( self.loss( - separable_filtering(y_pred, [self.kernel_fn(s)] * (y_true.ndim - 2)), - separable_filtering(y_true, [self.kernel_fn(s)] * (y_true.ndim - 2)), + separable_filtering(y_pred, [self.kernel_fn(s).to(y_pred)] * (y_true.ndim - 2)), + separable_filtering(y_true, [self.kernel_fn(s).to(y_pred)] * (y_true.ndim - 2)), ) ) loss = torch.stack(loss_list, dim=0) diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index f2b9a41cae..8f1fb43535 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -17,30 +17,32 @@ from monai.losses.deform import BendingEnergyLoss +device = "cuda" if torch.cuda.is_available() else "cpu" + TEST_CASES = [ [ {}, - {"pred": torch.ones((1, 3, 5, 5, 5))}, + {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0, ], [ {}, - {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0, ], [ {}, - {"pred": torch.arange(0, 5)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0, ], [ {}, - {"pred": torch.arange(0, 5)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, 4.0, ], [ {}, - {"pred": torch.arange(0, 5)[None, None, :].expand(1, 3, 5) ** 2}, + {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 3, 5) ** 2}, 4.0, ], ] @@ -56,19 +58,19 @@ def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3))) + loss.forward(torch.ones((1, 3), device=device)) with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 5, 5, 5, 5))) + loss.forward(torch.ones((1, 3, 5, 5, 5, 5), device=device)) # spatial_dim < 5 with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 4, 5, 5))) + loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 4, 5))) with self.assertRaisesRegex(ValueError, ""): loss.forward(torch.ones((1, 3, 5, 5, 4))) def test_ill_opts(self): - pred = torch.rand(1, 3, 5, 5, 5) + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) with self.assertRaisesRegex(ValueError, ""): BendingEnergyLoss(reduction="unknown")(pred) with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py index 252a70e85e..3373b59621 100644 --- a/tests/test_global_mutual_information_loss.py +++ b/tests/test_global_mutual_information_loss.py @@ -17,20 +17,30 @@ from monai.losses.image_dissimilarity import GlobalMutualInformationLoss +device = "cuda" if torch.cuda.is_available() else "cpu" + TEST_CASES = [ [ {}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3), }, -1.0986018, ], [ {}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None, None].expand(1, 3, 3, 3, 3).div(3) + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None, None] + .expand(1, 3, 3, 3, 3) + .div(3) ** 2, }, -1.083999, @@ -38,32 +48,35 @@ [ {}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None, None].expand(1, 3, 3, 3).div(3) ** 2, + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None].expand(1, 3, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None, None] + .expand(1, 3, 3, 3) + .div(3) + ** 2, }, -1.083999, ], [ {}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :, None].expand(1, 3, 3).div(3) ** 2, + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None].expand(1, 3, 3).div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :, None].expand(1, 3, 3).div(3) ** 2, }, -1.083999, ], [ {}, { - "pred": torch.arange(0, 3, dtype=torch.float)[None, :].div(3), - "target": torch.arange(0, 3, dtype=torch.float)[None, :].div(3) ** 2, + "pred": torch.arange(0, 3, dtype=torch.float, device=device)[None, :].div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device)[None, :].div(3) ** 2, }, -1.083999, ], [ {}, { - "pred": torch.arange(0, 3, dtype=torch.float).div(3), - "target": torch.arange(0, 3, dtype=torch.float).div(3) ** 2, + "pred": torch.arange(0, 3, dtype=torch.float, device=device).div(3), + "target": torch.arange(0, 3, dtype=torch.float, device=device).div(3) ** 2, }, -1.1920927e-07, ], @@ -79,13 +92,13 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = GlobalMutualInformationLoss() with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device)) with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)) + loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device)) def test_ill_opts(self): - pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) - target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) + pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device) + target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device) with self.assertRaisesRegex(ValueError, ""): GlobalMutualInformationLoss(num_bins=0)(pred, target) with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py index 8e9482596f..bddaedb54a 100644 --- a/tests/test_local_normalized_cross_correlation_loss.py +++ b/tests/test_local_normalized_cross_correlation_loss.py @@ -17,60 +17,89 @@ from monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss +device = "cuda" if torch.cuda.is_available() else "cpu" + TEST_CASES = [ [ {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular", "reduction": "sum"}, { - "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), - "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device), }, -1.0 * 3, ], [ {"in_channels": 1, "ndim": 1, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), - "target": torch.arange(0, 3).reshape(1, 1, -1).to(torch.float), + "pred": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device), + "target": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device), }, -1.0, ], [ {"in_channels": 1, "ndim": 2, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), - "target": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(torch.float), + "pred": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(dtype=torch.float, device=device), + "target": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(dtype=torch.float, device=device), }, -1.0, ], [ {"in_channels": 1, "ndim": 3, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), - "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 1, 3, 3, 3).to(torch.float), + "pred": torch.arange(0, 3) + .reshape(1, 1, -1, 1, 1) + .expand(1, 1, 3, 3, 3) + .to(dtype=torch.float, device=device), + "target": torch.arange(0, 3) + .reshape(1, 1, -1, 1, 1) + .expand(1, 1, 3, 3, 3) + .to(dtype=torch.float, device=device), }, -1.0, ], [ {"in_channels": 3, "ndim": 3, "kernel_type": "rectangular"}, { - "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), - "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, + "pred": torch.arange(0, 3) + .reshape(1, 1, -1, 1, 1) + .expand(1, 3, 3, 3, 3) + .to(dtype=torch.float, device=device), + "target": torch.arange(0, 3) + .reshape(1, 1, -1, 1, 1) + .expand(1, 3, 3, 3, 3) + .to(dtype=torch.float, device=device) + ** 2, }, -0.95801723, ], [ {"in_channels": 3, "ndim": 3, "kernel_type": "triangular", "kernel_size": 5}, { - "pred": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float), - "target": torch.arange(0, 5).reshape(1, 1, -1, 1, 1).expand(1, 3, 5, 5, 5).to(torch.float) ** 2, + "pred": torch.arange(0, 5) + .reshape(1, 1, -1, 1, 1) + .expand(1, 3, 5, 5, 5) + .to(dtype=torch.float, device=device), + "target": torch.arange(0, 5) + .reshape(1, 1, -1, 1, 1) + .expand(1, 3, 5, 5, 5) + .to(dtype=torch.float, device=device) + ** 2, }, -0.918672, ], [ {"in_channels": 3, "ndim": 3, "kernel_type": "gaussian"}, { - "pred": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float), - "target": torch.arange(0, 3).reshape(1, 1, -1, 1, 1).expand(1, 3, 3, 3, 3).to(torch.float) ** 2, + "pred": torch.arange(0, 3) + .reshape(1, 1, -1, 1, 1) + .expand(1, 3, 3, 3, 3) + .to(dtype=torch.float, device=device), + "target": torch.arange(0, 3) + .reshape(1, 1, -1, 1, 1) + .expand(1, 3, 3, 3, 3) + .to(dtype=torch.float, device=device) + ** 2, }, -0.95406944, ], @@ -87,13 +116,22 @@ def test_ill_shape(self): loss = LocalNormalizedCrossCorrelationLoss(in_channels=3, ndim=3) # in_channel unmatch with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 2, 3, 3, 3), dtype=torch.float), torch.ones((1, 2, 3, 3, 3), dtype=torch.float)) + loss.forward( + torch.ones((1, 2, 3, 3, 3), dtype=torch.float, device=device), + torch.ones((1, 2, 3, 3, 3), dtype=torch.float, device=device), + ) # ndim unmatch with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 3, 3), dtype=torch.float)) + loss.forward( + torch.ones((1, 3, 3, 3), dtype=torch.float, device=device), + torch.ones((1, 3, 3, 3), dtype=torch.float, device=device), + ) # pred, target shape unmatch with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 3, 3, 3), dtype=torch.float), torch.ones((1, 3, 4, 4, 4), dtype=torch.float)) + loss.forward( + torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device), + torch.ones((1, 3, 4, 4, 4), dtype=torch.float, device=device), + ) def test_ill_opts(self): pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float) diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py index 9ce1734e28..01a760db72 100644 --- a/tests/test_multi_scale.py +++ b/tests/test_multi_scale.py @@ -19,23 +19,30 @@ from tests.utils import SkipIfBeforePyTorchVersion, test_script_save dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) +device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASES = [ [ {"loss": dice_loss, "scales": None, "kernel": "gaussian"}, - {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + { + "y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device), + }, 0.307576, ], [ {"loss": dice_loss, "scales": [0, 1], "kernel": "gaussian"}, - {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, + { + "y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device), + "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device), + }, 0.463116, ], [ {"loss": dice_loss, "scales": [0, 1, 2], "kernel": "cauchy"}, { - "y_pred": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]]), - "y_true": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]]), + "y_pred": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device), + "y_true": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device), }, 0.715228, ], @@ -52,9 +59,13 @@ def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): MultiScaleLoss(loss=dice_loss, kernel="none") with self.assertRaisesRegex(ValueError, ""): - MultiScaleLoss(loss=dice_loss, scales=[-1])(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + MultiScaleLoss(loss=dice_loss, scales=[-1])( + torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device) + ) with self.assertRaisesRegex(ValueError, ""): - MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(torch.ones((1, 1, 3)), torch.ones((1, 1, 3))) + MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")( + torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device) + ) @SkipIfBeforePyTorchVersion((1, 7, 0)) def test_script(self):