Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/networks/blocks/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torch.nn.functional import softmax

from monai.networks.layers.filtering import PHLFilter
from monai.networks.utils import meshgrid_ij

__all__ = ["CRF"]

Expand Down Expand Up @@ -114,6 +115,6 @@ def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
# helper methods
def _create_coordinate_tensor(tensor):
axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())]
grids = torch.meshgrid(axes)
grids = meshgrid_ij(axes)
coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype)
return torch.stack(tensor.size(0) * [coords], dim=0)
3 changes: 2 additions & 1 deletion monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from monai.config.deviceconfig import USE_COMPILED
from monai.networks.layers.spatial_transforms import grid_pull
from monai.networks.utils import meshgrid_ij
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import

_C, _ = optional_import("monai._C")
Expand Down Expand Up @@ -84,7 +85,7 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
@staticmethod
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
grid = grid.to(ddf)
return grid
Expand Down
7 changes: 7 additions & 0 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"train_mode",
"copy_model_state",
"convert_to_torchscript",
"meshgrid_ij",
]


Expand Down Expand Up @@ -500,3 +501,9 @@ def convert_to_torchscript(
torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol)

return script_module


def meshgrid_ij(*tensors):
if pytorch_after(1, 10):
return torch.meshgrid(*tensors, indexing="ij")
return torch.meshgrid(*tensors)
8 changes: 3 additions & 5 deletions monai/transforms/smooth_field/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

import monai
from monai.config.type_definitions import NdarrayOrTensor
from monai.networks.utils import meshgrid_ij
from monai.transforms.transform import Randomizable, RandomizableTransform
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode
from monai.utils.enums import TransformBackends
from monai.utils.module import look_up_option, pytorch_after
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

__all__ = ["SmoothField", "RandSmoothFieldAdjustContrast", "RandSmoothFieldAdjustIntensity", "RandSmoothDeform"]
Expand Down Expand Up @@ -404,10 +405,7 @@ def __init__(
grid_space = spatial_size if spatial_size is not None else self.sfield.field.shape[2:]
grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space]

if pytorch_after(1, 10):
grid = torch.meshgrid(*grid_ranges, indexing="ij")
else:
grid = torch.meshgrid(*grid_ranges)
grid = meshgrid_ij(*grid_ranges)

self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype)

Expand Down
3 changes: 2 additions & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij
from monai.transforms.croppad.array import CenterSpatialCrop, Pad
from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform
from monai.transforms.utils import (
Expand Down Expand Up @@ -2103,7 +2104,7 @@ def __call__(
ranges = ranges - (dim_size - 1.0) / 2.0
all_ranges.append(ranges)

coords = torch.meshgrid(*all_ranges)
coords = meshgrid_ij(*all_ranges)
grid = torch.stack([*coords, torch.ones_like(coords[0])])

return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion tests/test_grid_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from parameterized import parameterized

from monai.networks.layers import grid_pull
from monai.networks.utils import meshgrid_ij
from monai.utils import optional_import
from tests.testing_data.cpp_resample_answers import Expected_1D_GP_bwd, Expected_1D_GP_fwd
from tests.utils import skip_if_no_cpp_extension
Expand All @@ -26,7 +27,7 @@

def make_grid(shape, dtype=None, device=None, requires_grad=True):
ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=requires_grad) for s in shape]
grid = torch.stack(torch.meshgrid(*ranges), dim=-1)
grid = torch.stack(meshgrid_ij(*ranges), dim=-1)
return grid[None]


Expand Down