diff --git a/docs/source/data.rst b/docs/source/data.rst index 0de5e0c347..eab4f867af 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -268,6 +268,12 @@ TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.TestTimeAugmentation +N-Dim Fourier Transform +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.data.fft_utils +.. autofunction:: monai.data.fft_utils.fftn_centered +.. autofunction:: monai.data.fft_utils.ifftn_centered + Meta Object ----------- diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 2164fe5d1b..a7b6d8cdc0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -233,6 +233,17 @@ Blocks .. autoclass:: DVF2DDF :members: + +N-Dim Fourier Transform +~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.networks.blocks.fft_utils_t +.. autofunction:: monai.networks.blocks.fft_utils_t.fftn_centered_t +.. autofunction:: monai.networks.blocks.fft_utils_t.ifftn_centered_t +.. autofunction:: monai.networks.blocks.fft_utils_t.roll +.. autofunction:: monai.networks.blocks.fft_utils_t.roll_1d +.. autofunction:: monai.networks.blocks.fft_utils_t.fftshift +.. 
autofunction:: monai.networks.blocks.fft_utils_t.ifftshift + Layers ------ diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index f0f95dbc3e..47e7b6911d 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -621,7 +621,7 @@ def centers_in_boxes(centers: NdarrayOrTensor, boxes: NdarrayOrTensor, eps: floa min_center_to_border: np.ndarray = np.stack(center_to_border, axis=1).min(axis=1) return min_center_to_border > eps # array[bool] - return torch.stack(center_to_border, dim=1).to(COMPUTE_DTYPE).min(dim=1)[0] > eps # Tensor[bool] + return torch.stack(center_to_border, dim=1).to(COMPUTE_DTYPE).min(dim=1)[0] > eps # type: ignore def boxes_center_distance( diff --git a/monai/data/fft_utils.py b/monai/data/fft_utils.py new file mode 100644 index 0000000000..2d38beed5e --- /dev/null +++ b/monai/data/fft_utils.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.config.type_definitions import NdarrayOrTensor +from monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type + + +def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: + """ + Pytorch-based ifft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. 
This function calls monai.networks.blocks.fft_utils_t.ifftn_centered_t. + This is equivalent to doing ifft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift + + Args: + ksp: k-space data that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) + + Returns: + "out" which is the output image (inverse fourier of ksp) + + Example: + + .. code-block:: python + + import torch + ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) + """ + # handle numpy format + ksp_t, *_ = convert_data_type(ksp, torch.Tensor) + + # compute ifftn + out_t = ifftn_centered_t(ksp_t, spatial_dims=spatial_dims, is_complex=is_complex) + + # handle numpy format + out, *_ = convert_to_dst_type(src=out_t, dst=ksp) + return out + + +def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: + """ + Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. This function calls monai.networks.blocks.fft_utils_t.fftn_centered_t. 
+ This is equivalent to doing fft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift + + Args: + im: image that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) + + Returns: + "out" which is the output kspace (fourier of im) + + Example: + + .. code-block:: python + + import torch + im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = fftn_centered(im, spatial_dims=2, is_complex=True) + """ + # handle numpy format + im_t, *_ = convert_data_type(im, torch.Tensor) + + # compute fftn + out_t = fftn_centered_t(im_t, spatial_dims=spatial_dims, is_complex=is_complex) + + # handle numpy format + out, *_ = convert_to_dst_type(src=out_t, dst=im) + return out diff --git a/monai/networks/blocks/feature_pyramid_network.py b/monai/networks/blocks/feature_pyramid_network.py index 2f7b903a19..2373cfc099 100644 --- a/monai/networks/blocks/feature_pyramid_network.py +++ b/monai/networks/blocks/feature_pyramid_network.py @@ -193,7 +193,7 @@ def __init__( conv_type_: Type[nn.Module] = Conv[Conv.CONV, spatial_dims] for m in self.modules(): if isinstance(m, conv_type_): - nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.kaiming_uniform_(m.weight, a=1) # type: ignore nn.init.constant_(m.bias, 0.0) # type: ignore if extra_blocks is not None: diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py 
new file mode 100644 index 0000000000..0d6b99d7e1 --- /dev/null +++ b/monai/networks/blocks/fft_utils_t.py @@ -0,0 +1,232 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence + +import torch +from torch import Tensor + +from monai.utils.type_conversion import convert_data_type + + +def roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor: + """ + Similar to roll but for only one dim. + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. 
+ shift: the amount of shift along each of shift_dims dimension + shift_dim: the dimension over which the shift is applied + + Returns: + 1d-shifted version of x + + Note: + This function is called when fftshift and ifftshift are not available in the running pytorch version + """ + shift = shift % x.size(shift_dim) + if shift == 0: + return x + + left = x.narrow(shift_dim, 0, x.size(shift_dim) - shift) + right = x.narrow(shift_dim, x.size(shift_dim) - shift, shift) + + return torch.cat((right, left), dim=shift_dim) + + +def roll(x: Tensor, shift: Sequence[int], shift_dims: Sequence[int]) -> Tensor: + """ + Similar to np.roll but applies to PyTorch Tensors + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + shift: the amount of shift along each of shift_dims dimensions + shift_dims: dimensions over which the shift is applied + + Returns: + shifted version of x + + Note: + This function is called when fftshift and ifftshift are not available in the running pytorch version + """ + if len(shift) != len(shift_dims): + raise ValueError(f"len(shift) != len(shift_dims), got f{len(shift)} and f{len(shift_dims)}.") + for s, d in zip(shift, shift_dims): + x = roll_1d(x, s, d) + return x + + +def fftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor: + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. 
+ shift_dims: dimensions over which the shift is applied + + Returns: + fft-shifted version of x + + Note: + This function is called when fftshift is not available in the running pytorch version + """ + if shift_dims is None: + # for torch.jit.script based on the fastmri repository + shift_dims = [0] * (x.dim()) + for i in range(1, x.dim()): + shift_dims[i] = i + shift = [0] * len(shift_dims) + for i, dim_num in enumerate(shift_dims): + shift[i] = x.shape[dim_num] // 2 + return roll(x, shift, shift_dims) + + +def ifftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor: + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + shift_dims: dimensions over which the shift is applied + + Returns: + ifft-shifted version of x + + Note: + This function is called when ifftshift is not available in the running pytorch version + """ + if shift_dims is None: + # for torch.jit.script based on the fastmri repository + shift_dims = [0] * (x.dim()) + for i in range(1, x.dim()): + shift_dims[i] = i + shift = [0] * len(shift_dims) + for i, dim_num in enumerate(shift_dims): + shift[i] = (x.shape[dim_num] + 1) // 2 + return roll(x, shift, shift_dims) + + +def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: + """ + Pytorch-based ifft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. 
+ This is equivalent to doing ifft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift + + Args: + ksp: k-space data that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) + + Returns: + "out" which is the output image (inverse fourier of ksp) + + Example: + + .. code-block:: python + + import torch + ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) + """ + # define spatial dims to perform ifftshift, fftshift, and ifft + shift = tuple(range(-spatial_dims, 0)) + if is_complex: + if ksp.shape[-1] != 2: + raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).") + shift = tuple(range(-spatial_dims - 1, -1)) + dims = tuple(range(-spatial_dims, 0)) + + # apply ifft + if hasattr(torch.fft, "ifftshift"): # ifftshift was added in pytorch 1.8 + x = torch.fft.ifftshift(ksp, dim=shift) + else: + x = ifftshift(ksp, shift) + + if is_complex: + x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) + else: + x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) + + if hasattr(torch.fft, "fftshift"): + out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] + else: + out = convert_data_type(fftshift(x, shift), torch.Tensor)[0] + + return out + + +def fftn_centered_t(im: Tensor, spatial_dims: 
int, is_complex: bool = True) -> Tensor: + """ + Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. + This is equivalent to doing fft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift + + Args: + im: image that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) + + Returns: + "out" which is the output kspace (fourier of im) + + Example: + + .. code-block:: python + + import torch + im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = fftn_centered(im, spatial_dims=2, is_complex=True) + """ + # define spatial dims to perform ifftshift, fftshift, and fft + shift = tuple(range(-spatial_dims, 0)) + if is_complex: + if im.shape[-1] != 2: + raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") + shift = tuple(range(-spatial_dims - 1, -1)) + dims = tuple(range(-spatial_dims, 0)) + + # apply fft + if hasattr(torch.fft, "ifftshift"): # ifftshift was added in pytorch 1.8 + x = torch.fft.ifftshift(im, dim=shift) + else: + x = ifftshift(im, shift) + + if is_complex: + x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) + else: + x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) + + if hasattr(torch.fft, "fftshift"): + out = convert_data_type(torch.fft.fftshift(x, 
dim=shift), torch.Tensor)[0] + else: + out = convert_data_type(fftshift(x, shift), torch.Tensor)[0] + + return out diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0d5bafb026..d50505fe84 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -313,10 +313,10 @@ def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: if isinstance(img, torch.Tensor): outputs = list(torch.split(img, 1, self.dim)) else: - outputs = np.split(img, n_out, self.dim) + outputs = np.split(img, n_out, self.dim) # type: ignore if not self.keepdim: outputs = [o.squeeze(self.dim) for o in outputs] - return outputs + return outputs # type: ignore @deprecated(since="0.8", msg_suffix="please use `SplitDim` instead.") diff --git a/tests/test_fft_utils.py b/tests/test_fft_utils.py new file mode 100644 index 0000000000..d5e3a22eaa --- /dev/null +++ b/tests/test_fft_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from parameterized import parameterized + +from monai.data.fft_utils import fftn_centered, ifftn_centered +from tests.utils import TEST_NDARRAYS, assert_allclose + +# +im = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] +res = [ + [[[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] +] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append((p(im), p(res))) + +# +TESTS_CONSISTENCY = [] +for p in TEST_NDARRAYS: + TESTS_CONSISTENCY.append(p(im)) + +# +im_complex = [ + [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]] +] +TESTS_CONSISTENCY_COMPLEX = [] +for p in TEST_NDARRAYS: + TESTS_CONSISTENCY_COMPLEX.append(p(im_complex)) + + +class TestFFT(unittest.TestCase): + @parameterized.expand(TESTS) + def test(self, test_data, res_data): + result = fftn_centered(test_data, spatial_dims=2, is_complex=False) + assert_allclose(result, res_data, type_test=True) + + @parameterized.expand(TESTS_CONSISTENCY) + def test_consistency(self, test_data): + result = fftn_centered(test_data, spatial_dims=2, is_complex=False) + result = ifftn_centered(result, spatial_dims=2, is_complex=True) + result = (result[..., 0] ** 2 + result[..., 1] ** 2) ** 0.5 + assert_allclose(result, test_data, type_test=False) + + @parameterized.expand(TESTS_CONSISTENCY_COMPLEX) + def test_consistency_complex(self, test_data): + result = fftn_centered(test_data, spatial_dims=2) + result = ifftn_centered(result, spatial_dims=2) + assert_allclose(result, test_data, type_test=False) + + +if __name__ == "__main__": + unittest.main()