From e0d20e56fb0630bb1550f59fa95c723d74a9f425 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Tue, 21 Jun 2022 07:45:19 -0700 Subject: [PATCH 01/17] mri utils added Signed-off-by: mersad95zd --- monai/transforms/mri_utils.py | 235 ++++++++++++++++++++++++++++++++++ 1 file changed, 235 insertions(+) create mode 100644 monai/transforms/mri_utils.py diff --git a/monai/transforms/mri_utils.py b/monai/transforms/mri_utils.py new file mode 100644 index 0000000000..0bca2ae5a5 --- /dev/null +++ b/monai/transforms/mri_utils.py @@ -0,0 +1,235 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch + + +def to_tensor(data): + """ + Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts + are stacked along the last dimension. + inputs: + data (np.array): Input numpy array + outputs: + torch.Tensor: PyTorch version of data + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + return torch.from_numpy(data) + + +def complex_abs(x): + """ + Compute the absolute value of a complex array. + inputs: + x (np.array): Input numpy array with 2 channels in the last + dimension representing real and imaginary parts. + outputs: + np.array: Absolute value along the last dimention + """ + assert x.shape[-1] == 2 + return np.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2) + + +def roll(x, shift, dim): + """ + Similar to np.roll but applies to PyTorch Tensors + """ + if isinstance(shift, (tuple, list)): + assert len(shift) == len(dim) + for s, d in zip(shift, dim): + x = roll(x, s, d) + return x + shift = shift % x.size(dim) + if shift == 0: + return x + left = x.narrow(dim, 0, x.size(dim) - shift) + right = x.narrow(dim, x.size(dim) - shift, shift) + return torch.cat((right, left), dim=dim) + + +def fftshift(x, dim=None): + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [dim // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = x.shape[dim] // 2 + else: + shift = [x.shape[i] // 2 for i in dim] + return roll(x, shift, dim) + + +def ifftshift(x, dim=None): + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + """ + if dim is None: + dim = tuple(range(x.dim())) + shift = [(dim + 1) // 2 for dim in x.shape] + elif isinstance(dim, int): + shift = (x.shape[dim] + 1) // 2 + else: + shift = [(x.shape[i] + 1) // 2 for i in dim] + return roll(x, shift, dim) + + +def fft(input, signal_ndim, normalized=False): + """ + This function is called from the fft2 function below + """ + if signal_ndim < 1 or signal_ndim > 3: + print("Signal ndim out of range, was", signal_ndim, "but expected a value between 1 and 3, inclusive") + return + + dims = -1 + if signal_ndim == 2: + dims = (-2, -1) + if signal_ndim == 3: + dims = (-3, -2, -1) + + norm = "backward" + if normalized: + norm = "ortho" + + return torch.view_as_real(torch.fft.fftn(torch.view_as_complex(input), dim=dims, norm=norm)) + + +def ifft(input, signal_ndim, normalized=False): + """ + This function is called from the ifft2 function below + """ + if signal_ndim < 1 or signal_ndim > 3: + print("Signal ndim out of range, was", signal_ndim, "but expected a value between 1 and 3, inclusive") + return + + dims = -1 + if signal_ndim == 2: + dims = (-2, -1) + if signal_ndim == 3: + dims = (-3, -2, -1) + + norm = "backward" + if normalized: + norm = "ortho" + + return torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(input), dim=dims, norm=norm)) + + +def fft2(data): + """ + ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri + Apply centered 2 dimensional Fast Fourier Transform. It calls the fft function above to make it compatible with the latest version of pytorch. + inputs: + data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions + -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are + assumed to be batch dimensions. + outputs: + torch.Tensor: The FFT of the input. + """ + assert data.size(-1) == 2 + data = ifftshift(data, dim=(-3, -2)) + data = fft(data, 2, normalized=True) + data = fftshift(data, dim=(-3, -2)) + return data + + +def ifft2(data): + """ + ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri + Apply centered 2-dimensional Inverse Fast Fourier Transform. It calls the ifft function above to make it compatible with the latest version of pytorch. + inputs: + data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions + -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are + assumed to be batch dimensions. + outputs: + torch.Tensor: The IFFT of the input. + """ + assert data.size(-1) == 2 + data = ifftshift(data, dim=(-3, -2)) + data = ifft(data, 2, normalized=True) + data = fftshift(data, dim=(-3, -2)) + return data + + +def apply_mask(data, mask_func=None, mask=None, seed=None): + """ + Subsample given k-space by multiplying with a mask. + inputs: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + outputs: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + if mask is None: + mask = mask_func(shape, seed) + return data * mask, mask + + +def center_crop(data, shape): + """ + Apply a center crop to the input real image or batch of real images. + inputs: + data (torch.Tensor): The input tensor to be center cropped. It should have at + least 2 dimensions and the cropping is applied along the last two dimensions. + shape (int, int): The output shape. The shape should be smaller than the + corresponding dimensions of data. + outputs: + torch.Tensor: The center cropped image + """ + assert 0 < shape[0] <= data.shape[-2] + assert 0 < shape[1] <= data.shape[-1] + w_from = (data.shape[-2] - shape[0]) // 2 + h_from = (data.shape[-1] - shape[1]) // 2 + w_to = w_from + shape[0] + h_to = h_from + shape[1] + return data[..., w_from:w_to, h_from:h_to] + + +def normalize(data, mean, stddev, eps=0.0): + """ + Normalize (standardize in this case) the given tensor using: + (data - mean) / (stddev + eps) + inputss: + data (torch.Tensor): Input data to be normalized + mean (float): Mean value + stddev (float): Standard deviation + eps (float): Added to stddev to prevent dividing by zero + outputs: + torch.Tensor: Normalized tensor + """ + return (data - mean) / (stddev + eps) + + +def normalize_instance(data, eps=0.0): + """ + Normalize (standardize in this case) the given tensor using: + (data - mean) / (stddev + eps) + where mean and stddev are computed from the data itself. + inputs: + data (torch.Tensor): Input data to be normalized + eps (float): Added to stddev to prevent dividing by zero + outputs: + torch.Tensor: Normalized tensor + """ + mean = data.mean() + std = data.std() + return normalize(data, mean, std, eps), mean, std From aecedaaa679ea068dee934411d188354366d1e34 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Wed, 22 Jun 2022 15:16:27 -0700 Subject: [PATCH 02/17] fft_utils with its unit test added Signed-off-by: mersad95zd --- monai/apps/reconstruction/fft_utils.py | 86 +++++++ monai/apps/reconstruction/mri_utils.py | 246 ++++++++++++++++++++ monai/apps/reconstruction/test_fft_utils.py | 61 +++++ monai/transforms/mri_utils.py | 235 ------------------- 4 files changed, 393 insertions(+), 235 deletions(-) create mode 100644 monai/apps/reconstruction/fft_utils.py create mode 100644 monai/apps/reconstruction/mri_utils.py create mode 100644 monai/apps/reconstruction/test_fft_utils.py delete mode 100644 monai/transforms/mri_utils.py diff --git a/monai/apps/reconstruction/fft_utils.py b/monai/apps/reconstruction/fft_utils.py new file mode 100644 index 0000000000..bc48856b5f --- /dev/null +++ b/monai/apps/reconstruction/fft_utils.py @@ -0,0 +1,86 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from numpy import ndarray +from torch import Tensor + +from monai.config.type_definitions import NdarrayOrTensor + + +def ifftn(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Tensor: + """ + Pytorch-based ifft for spatial_dims-dim signals. + inputs: + ksp: k-space data + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) + """ + # handle numpy format + isnp = False + if isinstance(ksp, ndarray): + ksp = torch.from_numpy(ksp) + isnp = True + + # define spatial dims to perform ifftshift, fftshift, and ifft + shift = tuple(range(-spatial_dims, 0)) + if is_complex: + assert ksp.shape[-1] == 2 + shift = tuple(range(-spatial_dims - 1, -1)) + dims = tuple(range(-spatial_dims, 0)) + + # apply ifft + x = torch.fft.ifftshift(ksp, dim=shift) + if is_complex: + x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) + else: + x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) + out = torch.fft.fftshift(x, dim=shift) + + # handle numpy format + if isnp: + out = out.numpy() + return out + + +def fftn(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Tensor: + """ + Pytorch-based fft for spatial_dims-dim signals. + inputs: + im: image + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) + """ + # handle numpy format + isnp = False + if isinstance(im, ndarray): + im = torch.from_numpy(im) + isnp = True + + # define spatial dims to perform ifftshift, fftshift, and fft + shift = tuple(range(-spatial_dims, 0)) + if is_complex: + assert im.shape[-1] == 2 + shift = tuple(range(-spatial_dims - 1, -1)) + dims = tuple(range(-spatial_dims, 0)) + + # apply fft + x = torch.fft.ifftshift(im, dim=shift) + if is_complex: + x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) + else: + x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) + out = torch.fft.fftshift(x, dim=shift) + + # handle numpy format + if isnp: + out = out.numpy() + return out diff --git a/monai/apps/reconstruction/mri_utils.py b/monai/apps/reconstruction/mri_utils.py new file mode 100644 index 0000000000..c6a0060afe --- /dev/null +++ b/monai/apps/reconstruction/mri_utils.py @@ -0,0 +1,246 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import numpy as np +import torch +from numpy import ndarray +from torch import Tensor + +import monai + + +def convert_to_tensor_complex(data: ndarray) -> Tensor: + """ + Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts + are stacked along the last dimension. + inputs: + data (np.array): Input numpy array + outputs: + torch.Tensor: PyTorch version of data + """ + if np.iscomplexobj(data): + data = np.stack((data.real, data.imag), axis=-1) + return monai.utils.type_conversion.convert_to_tensor(data) + + +def complex_abs(x: ndarray) -> ndarray: + """ + Compute the absolute value of a complex array. + inputs: + x (np.array): Input numpy array with 2 channels in the last + dimension representing real and imaginary parts. + outputs: + np.array: Absolute value along the last dimention + """ + assert x.shape[-1] == 2 + return np.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2) + + +# mask functions +class MaskFunc: + def __init__(self, center_fractions: list, accelerations: list) -> None: + """ + inputs: + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly + each time. + + accelerations (List[int]): Amount of under-sampling. This should have the same length + as center_fractions. If multiple values are provided, then one of these is chosen + uniformly each time. + """ + if len(center_fractions) != len(accelerations): + raise ValueError("Number of center fractions should match number of accelerations") + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() + + def choose_acceleration(self): + choice = self.rng.randint(0, len(self.accelerations)) + center_fraction = self.center_fractions[choice] + acceleration = self.accelerations[choice] + return center_fraction, acceleration + + +def create_mask_for_mask_type(mask_type_str: str, center_fractions: list, accelerations: list) -> MaskFunc: + """ + Create an under-sampling mask generator + inputs: + mask_type_str (string): denotes the mask type ('random','equispaced') + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly each time. + accelerations (List[int]): Amount of under-sampling. This should have the same length as center_fractions. + If multiple values are provided, then one of these is chosen uniformly each time. + outputs: + callable mask function + """ + if mask_type_str == "random": + return RandomMaskFunc(center_fractions, accelerations) + elif mask_type_str == "equispaced": + return EquispacedMaskFunc(center_fractions, accelerations) + else: + raise Exception(f"{mask_type_str} not supported") + + +class RandomMaskFunc(MaskFunc): + """ + ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri + RandomMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the k-space data has N + columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to + low-frequencies + 2. The other columns are selected uniformly at random with a probability equal to: + prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). + This ensures that the expected number of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the RandomMaskFunc object is + called. + + For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there + is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% + probability that 8-fold acceleration with 4% center fraction is selected. + """ + + def __init__(self, center_fractions: list, accelerations: list) -> None: + """ + inputs: + center_fractions (List[float]): Fraction of low-frequency columns to be retained. + If multiple values are provided, then one of these numbers is chosen uniformly + each time. + + accelerations (List[int]): Amount of under-sampling. This should have the same length + as center_fractions. If multiple values are provided, then one of these is chosen + uniformly each time. An acceleration of 4 retains 25% of the columns, but they may + not be spaced evenly. + """ + if len(center_fractions) != len(accelerations): + raise ValueError("Number of center fractions should match number of accelerations") + + self.center_fractions = center_fractions + self.accelerations = accelerations + self.rng = np.random.RandomState() + + def __call__(self, spatial_size: tuple, seed: Optional[int] = None) -> Tensor: + """ + inputs: + shape (iterable[int]): The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same shape. + outputs: + torch.Tensor: A mask of the specified shape. + """ + if len(spatial_size) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + self.rng.seed(seed) + num_cols = spatial_size[-2] + center_fraction, acceleration = self.choose_acceleration() + + # Create the mask + num_low_freqs = int(round(num_cols * center_fraction)) + prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) + mask = self.rng.uniform(size=num_cols) < prob + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # Reshape the mask + mask_shape = [1 for _ in spatial_size] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +class EquispacedMaskFunc(MaskFunc): + """ + ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri + EquispacedMaskFunc creates a sub-sampling mask of a given shape. + + The mask selects a subset of columns from the input k-space data. If the k-space data has N + columns, the mask picks out: + 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to + low-frequencies + 2. The other columns are selected with equal spacing at a proportion that reaches the + desired acceleration rate taking into consideration the number of low frequencies. This + ensures that the expected number of columns selected is equal to (N / acceleration) + + It is possible to use multiple center_fractions and accelerations, in which case one possible + (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc + object is called. + """ + + def __call__(self, spatial_size: tuple, seed: Optional[int] = None) -> Tensor: + """ + inputs: + shape (iterable[int]): The shape of the mask to be created. The shape should have + at least 3 dimensions. Samples are drawn along the second last dimension. + seed (int, optional): Seed for the random number generator. Setting the seed + ensures the same mask is generated each time for the same shape. + outputs: + torch.Tensor: A mask of the specified shape. + """ + if len(spatial_size) < 3: + raise ValueError("Shape should have 3 or more dimensions") + + self.rng.seed(seed) + center_fraction, acceleration = self.choose_acceleration() + num_cols = spatial_size[-2] + num_low_freqs = int(round(num_cols * center_fraction)) + + # Create the mask + mask = np.zeros(num_cols, dtype=np.float32) + pad = (num_cols - num_low_freqs + 1) // 2 + mask[pad : pad + num_low_freqs] = True + + # Determine acceleration rate by adjusting for the number of low frequencies + adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) + offset = self.rng.randint(0, round(adjusted_accel)) + + accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) + accel_samples = np.around(accel_samples).astype(np.uint) + mask[accel_samples] = True + + # Reshape the mask + mask_shape = [1 for _ in spatial_size] + mask_shape[-2] = num_cols + mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) + + return mask + + +def apply_mask( + data: Tensor, mask_func: Optional[MaskFunc] = None, mask: Optional[torch.Tensor] = None, seed: Optional[int] = None +) -> Tensor: + """ + Subsample given k-space by multiplying with a mask. + inputs: + data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where + dimensions -3 and -2 are the spatial dimensions, and the final dimension has size + 2 (for complex values). + mask_func (callable): A function that takes a shape (tuple of ints) and a random + number seed and returns a mask. + seed (int or 1-d array_like, optional): Seed for the random number generator. + outputs: + (tuple): tuple containing: + masked data (torch.Tensor): Subsampled k-space data + mask (torch.Tensor): The generated mask + """ + shape = np.array(data.shape) + shape[:-3] = 1 + if mask is None: + mask = mask_func(shape, seed) + return data * mask, mask diff --git a/monai/apps/reconstruction/test_fft_utils.py b/monai/apps/reconstruction/test_fft_utils.py new file mode 100644 index 0000000000..d719b9ba3d --- /dev/null +++ b/monai/apps/reconstruction/test_fft_utils.py @@ -0,0 +1,61 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from fft_utils import fftn, ifftn +from parameterized import parameterized + +from tests.utils import TEST_NDARRAYS, assert_allclose + +# +im = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]]] +res = [ + [[[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] +] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append((p(im), p(res))) + +# +TESTS_CONSISTENCY = [] +for p in TEST_NDARRAYS: + TESTS_CONSISTENCY.append(p(im)) + +# +im_complex = [ + [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]] +] +TESTS_CONSISTENCY_COMPLEX = [] +for p in TEST_NDARRAYS: + TESTS_CONSISTENCY_COMPLEX.append(p(im_complex)) + + +class TestFFT(unittest.TestCase): + @parameterized.expand(TESTS) + def test(self, test_data, res_data): + result = fftn(test_data, spatial_dims=2, is_complex=False) + assert_allclose(result, res_data, type_test=True) + + @parameterized.expand(TESTS_CONSISTENCY) + def test_consistency(self, test_data): + result = ifftn(fftn(test_data, spatial_dims=2, is_complex=False), spatial_dims=2, is_complex=True) + result = (result[..., 0] ** 2 + result[..., 1] ** 2) ** 0.5 + assert_allclose(result, test_data, type_test=False) + + @parameterized.expand(TESTS_CONSISTENCY_COMPLEX) + def test_consistency_complex(self, test_data): + result = ifftn(fftn(test_data, spatial_dims=2), spatial_dims=2) + assert_allclose(result, test_data, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/monai/transforms/mri_utils.py b/monai/transforms/mri_utils.py deleted file mode 100644 index 0bca2ae5a5..0000000000 --- a/monai/transforms/mri_utils.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import torch - - -def to_tensor(data): - """ - Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts - are stacked along the last dimension. - inputs: - data (np.array): Input numpy array - outputs: - torch.Tensor: PyTorch version of data - """ - if np.iscomplexobj(data): - data = np.stack((data.real, data.imag), axis=-1) - return torch.from_numpy(data) - - -def complex_abs(x): - """ - Compute the absolute value of a complex array. - inputs: - x (np.array): Input numpy array with 2 channels in the last - dimension representing real and imaginary parts. - outputs: - np.array: Absolute value along the last dimention - """ - assert x.shape[-1] == 2 - return np.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2) - - -def roll(x, shift, dim): - """ - Similar to np.roll but applies to PyTorch Tensors - """ - if isinstance(shift, (tuple, list)): - assert len(shift) == len(dim) - for s, d in zip(shift, dim): - x = roll(x, s, d) - return x - shift = shift % x.size(dim) - if shift == 0: - return x - left = x.narrow(dim, 0, x.size(dim) - shift) - right = x.narrow(dim, x.size(dim) - shift, shift) - return torch.cat((right, left), dim=dim) - - -def fftshift(x, dim=None): - """ - Similar to np.fft.fftshift but applies to PyTorch Tensors - """ - if dim is None: - dim = tuple(range(x.dim())) - shift = [dim // 2 for dim in x.shape] - elif isinstance(dim, int): - shift = x.shape[dim] // 2 - else: - shift = [x.shape[i] // 2 for i in dim] - return roll(x, shift, dim) - - -def ifftshift(x, dim=None): - """ - Similar to np.fft.ifftshift but applies to PyTorch Tensors - """ - if dim is None: - dim = tuple(range(x.dim())) - shift = [(dim + 1) // 2 for dim in x.shape] - elif isinstance(dim, int): - shift = (x.shape[dim] + 1) // 2 - else: - shift = [(x.shape[i] + 1) // 2 for i in dim] - return roll(x, shift, dim) - - -def fft(input, signal_ndim, normalized=False): - """ - This function is called from the fft2 function below - """ - if signal_ndim < 1 or signal_ndim > 3: - print("Signal ndim out of range, was", signal_ndim, "but expected a value between 1 and 3, inclusive") - return - - dims = -1 - if signal_ndim == 2: - dims = (-2, -1) - if signal_ndim == 3: - dims = (-3, -2, -1) - - norm = "backward" - if normalized: - norm = "ortho" - - return torch.view_as_real(torch.fft.fftn(torch.view_as_complex(input), dim=dims, norm=norm)) - - -def ifft(input, signal_ndim, normalized=False): - """ - This function is called from the ifft2 function below - """ - if signal_ndim < 1 or signal_ndim > 3: - print("Signal ndim out of range, was", signal_ndim, "but expected a value between 1 and 3, inclusive") - return - - dims = -1 - if signal_ndim == 2: - dims = (-2, -1) - if signal_ndim == 3: - dims = (-3, -2, -1) - - norm = "backward" - if normalized: - norm = "ortho" - - return torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(input), dim=dims, norm=norm)) - - -def fft2(data): - """ - ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri - Apply centered 2 dimensional Fast Fourier Transform. It calls the fft function above to make it compatible with the latest version of pytorch. - inputs: - data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions - -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are - assumed to be batch dimensions. - outputs: - torch.Tensor: The FFT of the input. - """ - assert data.size(-1) == 2 - data = ifftshift(data, dim=(-3, -2)) - data = fft(data, 2, normalized=True) - data = fftshift(data, dim=(-3, -2)) - return data - - -def ifft2(data): - """ - ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri - Apply centered 2-dimensional Inverse Fast Fourier Transform. It calls the ifft function above to make it compatible with the latest version of pytorch. - inputs: - data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions - -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are - assumed to be batch dimensions. - outputs: - torch.Tensor: The IFFT of the input. - """ - assert data.size(-1) == 2 - data = ifftshift(data, dim=(-3, -2)) - data = ifft(data, 2, normalized=True) - data = fftshift(data, dim=(-3, -2)) - return data - - -def apply_mask(data, mask_func=None, mask=None, seed=None): - """ - Subsample given k-space by multiplying with a mask. - inputs: - data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where - dimensions -3 and -2 are the spatial dimensions, and the final dimension has size - 2 (for complex values). - mask_func (callable): A function that takes a shape (tuple of ints) and a random - number seed and returns a mask. - seed (int or 1-d array_like, optional): Seed for the random number generator. - outputs: - (tuple): tuple containing: - masked data (torch.Tensor): Subsampled k-space data - mask (torch.Tensor): The generated mask - """ - shape = np.array(data.shape) - shape[:-3] = 1 - if mask is None: - mask = mask_func(shape, seed) - return data * mask, mask - - -def center_crop(data, shape): - """ - Apply a center crop to the input real image or batch of real images. - inputs: - data (torch.Tensor): The input tensor to be center cropped. It should have at - least 2 dimensions and the cropping is applied along the last two dimensions. - shape (int, int): The output shape. The shape should be smaller than the - corresponding dimensions of data. - outputs: - torch.Tensor: The center cropped image - """ - assert 0 < shape[0] <= data.shape[-2] - assert 0 < shape[1] <= data.shape[-1] - w_from = (data.shape[-2] - shape[0]) // 2 - h_from = (data.shape[-1] - shape[1]) // 2 - w_to = w_from + shape[0] - h_to = h_from + shape[1] - return data[..., w_from:w_to, h_from:h_to] - - -def normalize(data, mean, stddev, eps=0.0): - """ - Normalize (standardize in this case) the given tensor using: - (data - mean) / (stddev + eps) - inputss: - data (torch.Tensor): Input data to be normalized - mean (float): Mean value - stddev (float): Standard deviation - eps (float): Added to stddev to prevent dividing by zero - outputs: - torch.Tensor: Normalized tensor - """ - return (data - mean) / (stddev + eps) - - -def normalize_instance(data, eps=0.0): - """ - Normalize (standardize in this case) the given tensor using: - (data - mean) / (stddev + eps) - where mean and stddev are computed from the data itself. - inputs: - data (torch.Tensor): Input data to be normalized - eps (float): Added to stddev to prevent dividing by zero - outputs: - torch.Tensor: Normalized tensor - """ - mean = data.mean() - std = data.std() - return normalize(data, mean, std, eps), mean, std From 1e0e39c17ab4aa76ac6821f61ff9ca0ccc1ccff2 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Thu, 23 Jun 2022 01:10:52 -0700 Subject: [PATCH 03/17] fft_utils updated with monai data converter Signed-off-by: mersad95zd --- monai/apps/reconstruction/fft_utils.py | 33 ++++++++++++++------- monai/apps/reconstruction/test_fft_utils.py | 5 ++-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/monai/apps/reconstruction/fft_utils.py b/monai/apps/reconstruction/fft_utils.py index bc48856b5f..30ac3ebbf8 100644 --- a/monai/apps/reconstruction/fft_utils.py +++ b/monai/apps/reconstruction/fft_utils.py @@ -14,21 +14,26 @@ from torch import Tensor from monai.config.type_definitions import NdarrayOrTensor +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type def ifftn(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ Pytorch-based ifft for spatial_dims-dim signals. - inputs: + Args: ksp: k-space data spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) + Returns: + out: output image (inverse fourier of ksp) """ # handle numpy format isnp = False if isinstance(ksp, ndarray): - ksp = torch.from_numpy(ksp) + ksp_t, *_ = convert_data_type(ksp, torch.Tensor) isnp = True + else: + ksp_t = ksp.clone() # define spatial dims to perform ifftshift, fftshift, and ifft shift = tuple(range(-spatial_dims, 0)) @@ -38,32 +43,38 @@ def ifftn(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> T dims = tuple(range(-spatial_dims, 0)) # apply ifft - x = torch.fft.ifftshift(ksp, dim=shift) + x = torch.fft.ifftshift(ksp_t, dim=shift) if is_complex: x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) - out = torch.fft.fftshift(x, dim=shift) + out_t = torch.fft.fftshift(x, dim=shift) # handle numpy format if isnp: - out = out.numpy() + out, *_ = convert_to_dst_type(src=out_t, dst=ksp) + else: + out = out_t.clone() return out def fftn(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ Pytorch-based fft for spatial_dims-dim signals. - inputs: + Args: im: image spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) + Returns: + out: output kspace (fourier of im) """ # handle numpy format isnp = False if isinstance(im, ndarray): - im = torch.from_numpy(im) + im_t, *_ = convert_data_type(im, torch.Tensor) isnp = True + else: + im_t = im.clone() # define spatial dims to perform ifftshift, fftshift, and fft shift = tuple(range(-spatial_dims, 0)) @@ -73,14 +84,16 @@ def fftn(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Ten dims = tuple(range(-spatial_dims, 0)) # apply fft - x = torch.fft.ifftshift(im, dim=shift) + x = torch.fft.ifftshift(im_t, dim=shift) if is_complex: x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out = torch.fft.fftshift(x, dim=shift) + out_t = torch.fft.fftshift(x, dim=shift) # handle numpy format if isnp: - out = out.numpy() + out, *_ = convert_to_dst_type(src=out_t, dst=im) + else: + out = out_t.clone() return out diff --git a/monai/apps/reconstruction/test_fft_utils.py b/monai/apps/reconstruction/test_fft_utils.py index d719b9ba3d..4a1b3a367c 100644 --- a/monai/apps/reconstruction/test_fft_utils.py +++ b/monai/apps/reconstruction/test_fft_utils.py @@ -17,7 +17,7 @@ from tests.utils import TEST_NDARRAYS, assert_allclose # -im = [[[1, 1, 1], [1, 1, 1], [1, 1, 1]]] +im = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] res = [ [[[0.0, 0.0], [0.0, 3.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]] ] @@ -47,7 +47,8 @@ def test(self, test_data, res_data): @parameterized.expand(TESTS_CONSISTENCY) def test_consistency(self, test_data): - result = ifftn(fftn(test_data, spatial_dims=2, is_complex=False), spatial_dims=2, is_complex=True) + result = fftn(test_data, spatial_dims=2, is_complex=False) + result = ifftn(result, spatial_dims=2, is_complex=True) result = (result[..., 0] ** 2 + result[..., 1] ** 2) ** 0.5 assert_allclose(result, test_data, type_test=False) From a452ddc0b8599d000de846092f4ff94f2a63517c Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Thu, 23 Jun 2022 12:38:53 -0700 Subject: [PATCH 04/17] updated fft_util's docstring Signed-off-by: mersad95zd --- docs/source/apps.rst | 8 + monai/apps/reconstruction/fft_utils.py | 64 ++--- monai/apps/reconstruction/mri_utils.py | 246 -------------------- monai/apps/reconstruction/test_fft_utils.py | 11 +- 4 files changed, 52 insertions(+), 277 deletions(-) delete mode 100644 monai/apps/reconstruction/mri_utils.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 3bb85c9296..6ae15e6f64 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -184,3 +184,11 @@ Applications :members: .. automodule:: monai.apps.detection.metrics.matching :members: + + + +"Reconstruction" + +.. automodule:: monai.apps.reconstruction.fft_utils +.. autofunction:: monai.apps.reconstruction.fft_utils.fftn_centered +.. autofunction:: monai.apps.reconstruction.fft_utils.ifftn_centered diff --git a/monai/apps/reconstruction/fft_utils.py b/monai/apps/reconstruction/fft_utils.py index 30ac3ebbf8..3458da57ea 100644 --- a/monai/apps/reconstruction/fft_utils.py +++ b/monai/apps/reconstruction/fft_utils.py @@ -10,30 +10,38 @@ # limitations under the License. import torch -from numpy import ndarray -from torch import Tensor from monai.config.type_definitions import NdarrayOrTensor from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -def ifftn(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Tensor: +def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: """ Pytorch-based ifft for spatial_dims-dim signals. + This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifft.shift + Args: ksp: k-space data spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) + Returns: - out: output image (inverse fourier of ksp) + Union[ndarray,Tensor] "out" which is the output image (inverse fourier of ksp) + + Example: + + .. code-block:: python + + import torch + ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) """ # handle numpy format - isnp = False - if isinstance(ksp, ndarray): - ksp_t, *_ = convert_data_type(ksp, torch.Tensor) - isnp = True - else: - ksp_t = ksp.clone() + ksp_t, *_ = convert_data_type(ksp, torch.Tensor) # define spatial dims to perform ifftshift, fftshift, and ifft shift = tuple(range(-spatial_dims, 0)) @@ -51,30 +59,37 @@ def ifftn(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> T out_t = torch.fft.fftshift(x, dim=shift) # handle numpy format - if isnp: - out, *_ = convert_to_dst_type(src=out_t, dst=ksp) - else: - out = out_t.clone() + out, *_ = convert_to_dst_type(src=out_t, dst=ksp) return out -def fftn(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Tensor: +def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: """ Pytorch-based fft for spatial_dims-dim signals. + This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifft.shift + Args: im: image spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) + Returns: - out: output kspace (fourier of im) + Union[ndarray,Tensor] "out" which is the output kspace (fourier of im) + + Example: + + .. code-block:: python + + import torch + im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = fftn_centered(im, spatial_dims=2, is_complex=True) """ # handle numpy format - isnp = False - if isinstance(im, ndarray): - im_t, *_ = convert_data_type(im, torch.Tensor) - isnp = True - else: - im_t = im.clone() + im_t, *_ = convert_data_type(im, torch.Tensor) # define spatial dims to perform ifftshift, fftshift, and fft shift = tuple(range(-spatial_dims, 0)) @@ -92,8 +107,5 @@ def fftn(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> Ten out_t = torch.fft.fftshift(x, dim=shift) # handle numpy format - if isnp: - out, *_ = convert_to_dst_type(src=out_t, dst=im) - else: - out = out_t.clone() + out, *_ = convert_to_dst_type(src=out_t, dst=im) return out diff --git a/monai/apps/reconstruction/mri_utils.py b/monai/apps/reconstruction/mri_utils.py deleted file mode 100644 index c6a0060afe..0000000000 --- a/monai/apps/reconstruction/mri_utils.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import numpy as np -import torch -from numpy import ndarray -from torch import Tensor - -import monai - - -def convert_to_tensor_complex(data: ndarray) -> Tensor: - """ - Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts - are stacked along the last dimension. - inputs: - data (np.array): Input numpy array - outputs: - torch.Tensor: PyTorch version of data - """ - if np.iscomplexobj(data): - data = np.stack((data.real, data.imag), axis=-1) - return monai.utils.type_conversion.convert_to_tensor(data) - - -def complex_abs(x: ndarray) -> ndarray: - """ - Compute the absolute value of a complex array. - inputs: - x (np.array): Input numpy array with 2 channels in the last - dimension representing real and imaginary parts. - outputs: - np.array: Absolute value along the last dimention - """ - assert x.shape[-1] == 2 - return np.sqrt(x[..., 0] ** 2 + x[..., 1] ** 2) - - -# mask functions -class MaskFunc: - def __init__(self, center_fractions: list, accelerations: list) -> None: - """ - inputs: - center_fractions (List[float]): Fraction of low-frequency columns to be retained. - If multiple values are provided, then one of these numbers is chosen uniformly - each time. - - accelerations (List[int]): Amount of under-sampling. This should have the same length - as center_fractions. If multiple values are provided, then one of these is chosen - uniformly each time. - """ - if len(center_fractions) != len(accelerations): - raise ValueError("Number of center fractions should match number of accelerations") - - self.center_fractions = center_fractions - self.accelerations = accelerations - self.rng = np.random.RandomState() - - def choose_acceleration(self): - choice = self.rng.randint(0, len(self.accelerations)) - center_fraction = self.center_fractions[choice] - acceleration = self.accelerations[choice] - return center_fraction, acceleration - - -def create_mask_for_mask_type(mask_type_str: str, center_fractions: list, accelerations: list) -> MaskFunc: - """ - Create an under-sampling mask generator - inputs: - mask_type_str (string): denotes the mask type ('random','equispaced') - center_fractions (List[float]): Fraction of low-frequency columns to be retained. - If multiple values are provided, then one of these numbers is chosen uniformly each time. - accelerations (List[int]): Amount of under-sampling. This should have the same length as center_fractions. - If multiple values are provided, then one of these is chosen uniformly each time. - outputs: - callable mask function - """ - if mask_type_str == "random": - return RandomMaskFunc(center_fractions, accelerations) - elif mask_type_str == "equispaced": - return EquispacedMaskFunc(center_fractions, accelerations) - else: - raise Exception(f"{mask_type_str} not supported") - - -class RandomMaskFunc(MaskFunc): - """ - ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri - RandomMaskFunc creates a sub-sampling mask of a given shape. - - The mask selects a subset of columns from the input k-space data. If the k-space data has N - columns, the mask picks out: - 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to - low-frequencies - 2. The other columns are selected uniformly at random with a probability equal to: - prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). - This ensures that the expected number of columns selected is equal to (N / acceleration) - - It is possible to use multiple center_fractions and accelerations, in which case one possible - (center_fraction, acceleration) is chosen uniformly at random each time the RandomMaskFunc object is - called. - - For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there - is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% - probability that 8-fold acceleration with 4% center fraction is selected. - """ - - def __init__(self, center_fractions: list, accelerations: list) -> None: - """ - inputs: - center_fractions (List[float]): Fraction of low-frequency columns to be retained. - If multiple values are provided, then one of these numbers is chosen uniformly - each time. - - accelerations (List[int]): Amount of under-sampling. This should have the same length - as center_fractions. If multiple values are provided, then one of these is chosen - uniformly each time. An acceleration of 4 retains 25% of the columns, but they may - not be spaced evenly. - """ - if len(center_fractions) != len(accelerations): - raise ValueError("Number of center fractions should match number of accelerations") - - self.center_fractions = center_fractions - self.accelerations = accelerations - self.rng = np.random.RandomState() - - def __call__(self, spatial_size: tuple, seed: Optional[int] = None) -> Tensor: - """ - inputs: - shape (iterable[int]): The shape of the mask to be created. The shape should have - at least 3 dimensions. Samples are drawn along the second last dimension. - seed (int, optional): Seed for the random number generator. Setting the seed - ensures the same mask is generated each time for the same shape. - outputs: - torch.Tensor: A mask of the specified shape. - """ - if len(spatial_size) < 3: - raise ValueError("Shape should have 3 or more dimensions") - - self.rng.seed(seed) - num_cols = spatial_size[-2] - center_fraction, acceleration = self.choose_acceleration() - - # Create the mask - num_low_freqs = int(round(num_cols * center_fraction)) - prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) - mask = self.rng.uniform(size=num_cols) < prob - pad = (num_cols - num_low_freqs + 1) // 2 - mask[pad : pad + num_low_freqs] = True - - # Reshape the mask - mask_shape = [1 for _ in spatial_size] - mask_shape[-2] = num_cols - mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) - - return mask - - -class EquispacedMaskFunc(MaskFunc): - """ - ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri - EquispacedMaskFunc creates a sub-sampling mask of a given shape. - - The mask selects a subset of columns from the input k-space data. If the k-space data has N - columns, the mask picks out: - 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to - low-frequencies - 2. The other columns are selected with equal spacing at a proportion that reaches the - desired acceleration rate taking into consideration the number of low frequencies. This - ensures that the expected number of columns selected is equal to (N / acceleration) - - It is possible to use multiple center_fractions and accelerations, in which case one possible - (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc - object is called. - """ - - def __call__(self, spatial_size: tuple, seed: Optional[int] = None) -> Tensor: - """ - inputs: - shape (iterable[int]): The shape of the mask to be created. The shape should have - at least 3 dimensions. Samples are drawn along the second last dimension. - seed (int, optional): Seed for the random number generator. Setting the seed - ensures the same mask is generated each time for the same shape. - outputs: - torch.Tensor: A mask of the specified shape. - """ - if len(spatial_size) < 3: - raise ValueError("Shape should have 3 or more dimensions") - - self.rng.seed(seed) - center_fraction, acceleration = self.choose_acceleration() - num_cols = spatial_size[-2] - num_low_freqs = int(round(num_cols * center_fraction)) - - # Create the mask - mask = np.zeros(num_cols, dtype=np.float32) - pad = (num_cols - num_low_freqs + 1) // 2 - mask[pad : pad + num_low_freqs] = True - - # Determine acceleration rate by adjusting for the number of low frequencies - adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) - offset = self.rng.randint(0, round(adjusted_accel)) - - accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) - accel_samples = np.around(accel_samples).astype(np.uint) - mask[accel_samples] = True - - # Reshape the mask - mask_shape = [1 for _ in spatial_size] - mask_shape[-2] = num_cols - mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) - - return mask - - -def apply_mask( - data: Tensor, mask_func: Optional[MaskFunc] = None, mask: Optional[torch.Tensor] = None, seed: Optional[int] = None -) -> Tensor: - """ - Subsample given k-space by multiplying with a mask. - inputs: - data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where - dimensions -3 and -2 are the spatial dimensions, and the final dimension has size - 2 (for complex values). - mask_func (callable): A function that takes a shape (tuple of ints) and a random - number seed and returns a mask. - seed (int or 1-d array_like, optional): Seed for the random number generator. - outputs: - (tuple): tuple containing: - masked data (torch.Tensor): Subsampled k-space data - mask (torch.Tensor): The generated mask - """ - shape = np.array(data.shape) - shape[:-3] = 1 - if mask is None: - mask = mask_func(shape, seed) - return data * mask, mask diff --git a/monai/apps/reconstruction/test_fft_utils.py b/monai/apps/reconstruction/test_fft_utils.py index 4a1b3a367c..df4abc071d 100644 --- a/monai/apps/reconstruction/test_fft_utils.py +++ b/monai/apps/reconstruction/test_fft_utils.py @@ -11,7 +11,7 @@ import unittest -from fft_utils import fftn, ifftn +from fft_utils import fftn_centered, ifftn_centered from parameterized import parameterized from tests.utils import TEST_NDARRAYS, assert_allclose @@ -42,19 +42,20 @@ class TestFFT(unittest.TestCase): @parameterized.expand(TESTS) def test(self, test_data, res_data): - result = fftn(test_data, spatial_dims=2, is_complex=False) + result = fftn_centered(test_data, spatial_dims=2, is_complex=False) assert_allclose(result, res_data, type_test=True) @parameterized.expand(TESTS_CONSISTENCY) def test_consistency(self, test_data): - result = fftn(test_data, spatial_dims=2, is_complex=False) - result = ifftn(result, spatial_dims=2, is_complex=True) + result = fftn_centered(test_data, spatial_dims=2, is_complex=False) + result = ifftn_centered(result, spatial_dims=2, is_complex=True) result = (result[..., 0] ** 2 + result[..., 1] ** 2) ** 0.5 assert_allclose(result, test_data, type_test=False) @parameterized.expand(TESTS_CONSISTENCY_COMPLEX) def test_consistency_complex(self, test_data): - result = ifftn(fftn(test_data, spatial_dims=2), spatial_dims=2) + result = fftn_centered(test_data, spatial_dims=2) + result = ifftn_centered(result, spatial_dims=2) assert_allclose(result, test_data, type_test=False) From 962f5f8cdc81e51f9798f8d3fcb57b1277a937d7 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Thu, 23 Jun 2022 13:12:54 -0700 Subject: [PATCH 05/17] apps.rst updated with fft_utils docstrings under the reconstruction module Signed-off-by: mersad95zd --- docs/source/apps.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 6ae15e6f64..cd1a3bdd3d 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -187,7 +187,8 @@ Applications -"Reconstruction" +`Reconstruction` +----------------- .. automodule:: monai.apps.reconstruction.fft_utils .. autofunction:: monai.apps.reconstruction.fft_utils.fftn_centered From 829992c400b77ea9ed2b011f22efdb674d05589d Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Thu, 23 Jun 2022 13:42:55 -0700 Subject: [PATCH 06/17] fft_utils docstring updated by adding dimension hins Signed-off-by: mersad95zd --- monai/apps/reconstruction/fft_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/apps/reconstruction/fft_utils.py b/monai/apps/reconstruction/fft_utils.py index 3458da57ea..e93dce700b 100644 --- a/monai/apps/reconstruction/fft_utils.py +++ b/monai/apps/reconstruction/fft_utils.py @@ -18,10 +18,12 @@ def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: """ Pytorch-based ifft for spatial_dims-dim signals. - This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifft.shift + This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift Args: - ksp: k-space data + ksp: k-space data that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) @@ -66,10 +68,12 @@ def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = T def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: """ Pytorch-based fft for spatial_dims-dim signals. - This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifft.shift + This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift Args: - im: image + im: image that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) From 8536cd316e2dab0c2d85fe56a2e2697d3dca765e Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Thu, 23 Jun 2022 13:51:33 -0700 Subject: [PATCH 07/17] fft_utils docstring updated by removing redundant output type Signed-off-by: mersad95zd --- monai/apps/reconstruction/fft_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/apps/reconstruction/fft_utils.py b/monai/apps/reconstruction/fft_utils.py index e93dce700b..e818772b02 100644 --- a/monai/apps/reconstruction/fft_utils.py +++ b/monai/apps/reconstruction/fft_utils.py @@ -23,12 +23,12 @@ def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = T Args: ksp: k-space data that can be 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or - 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) Returns: - Union[ndarray,Tensor] "out" which is the output image (inverse fourier of ksp) + "out" which is the output image (inverse fourier of ksp) Example: @@ -73,12 +73,12 @@ def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = Tru Args: im: image that can be 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or - 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) Returns: - Union[ndarray,Tensor] "out" which is the output kspace (fourier of im) + "out" which is the output kspace (fourier of im) Example: From ae346ee4d84d6f6ee8d5aab59722d04d1526b1b6 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Thu, 23 Jun 2022 14:16:40 -0700 Subject: [PATCH 08/17] test_fft_utils.py moved to the tests folder Signed-off-by: mersad95zd --- {monai/apps/reconstruction => tests}/test_fft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename {monai/apps/reconstruction => tests}/test_fft_utils.py (96%) diff --git a/monai/apps/reconstruction/test_fft_utils.py b/tests/test_fft_utils.py similarity index 96% rename from monai/apps/reconstruction/test_fft_utils.py rename to tests/test_fft_utils.py index df4abc071d..875e8dbdc4 100644 --- a/monai/apps/reconstruction/test_fft_utils.py +++ b/tests/test_fft_utils.py @@ -11,7 +11,7 @@ import unittest -from fft_utils import fftn_centered, ifftn_centered +from monai.apps.reconstruction.fft_utils import fftn_centered, ifftn_centered from parameterized import parameterized from tests.utils import TEST_NDARRAYS, assert_allclose From 58086b3f26db71bd5259ececd3b5596144864882 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Fri, 24 Jun 2022 14:31:35 -0700 Subject: [PATCH 09/17] created fft_utils_t, the torch-only version of fft_utils Signed-off-by: mersad95zd --- docs/source/apps.rst | 6 -- docs/source/data.rst | 9 ++ monai/data/fft_utils.py | 94 +++++++++++++++++++ .../blocks/fft_utils_t.py} | 30 +++--- tests/test_fft_utils.py | 2 +- 5 files changed, 115 insertions(+), 26 deletions(-) create mode 100644 monai/data/fft_utils.py rename monai/{apps/reconstruction/fft_utils.py => networks/blocks/fft_utils_t.py} (80%) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index cd1a3bdd3d..4e5af8fe0c 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -187,9 +187,3 @@ Applications -`Reconstruction` ------------------ - -.. automodule:: monai.apps.reconstruction.fft_utils -.. autofunction:: monai.apps.reconstruction.fft_utils.fftn_centered -.. autofunction:: monai.apps.reconstruction.fft_utils.ifftn_centered diff --git a/docs/source/data.rst b/docs/source/data.rst index 0de5e0c347..7d57221830 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -268,6 +268,13 @@ TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.TestTimeAugmentation +N-Dim Fourier Transform +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.data.fft_utils +.. autofunction:: monai.data.fft_utils.fftn_centered +.. autofunction:: monai.data.fft_utils.ifftn_centered + + Meta Object ----------- @@ -326,3 +333,5 @@ Bounding box ------------ .. automodule:: monai.data.box_utils :members: + + diff --git a/monai/data/fft_utils.py b/monai/data/fft_utils.py new file mode 100644 index 0000000000..2d38beed5e --- /dev/null +++ b/monai/data/fft_utils.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.config.type_definitions import NdarrayOrTensor +from monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type + + +def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: + """ + Pytorch-based ifft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. This function calls monai.metworks.blocks.fft_utils_t.ifftn_centered_t. + This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift + + Args: + ksp: k-space data that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels) + + Returns: + "out" which is the output image (inverse fourier of ksp) + + Example: + + .. code-block:: python + + import torch + ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) + """ + # handle numpy format + ksp_t, *_ = convert_data_type(ksp, torch.Tensor) + + # compute ifftn + out_t = ifftn_centered_t(ksp_t, spatial_dims=spatial_dims, is_complex=is_complex) + + # handle numpy format + out, *_ = convert_to_dst_type(src=out_t, dst=ksp) + return out + + +def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: + """ + Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. This function calls monai.metworks.blocks.fft_utils_t.fftn_centered_t. + This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift + + Args: + im: image that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume) + is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels) + + Returns: + "out" which is the output kspace (fourier of im) + + Example: + + .. code-block:: python + + import torch + im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts + # output1 and output2 will be identical + output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm="ortho") + output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) ) + + output2 = fftn_centered(im, spatial_dims=2, is_complex=True) + """ + # handle numpy format + im_t, *_ = convert_data_type(im, torch.Tensor) + + # compute ifftn + out_t = fftn_centered_t(im_t, spatial_dims=spatial_dims, is_complex=is_complex) + + # handle numpy format + out, *_ = convert_to_dst_type(src=out_t, dst=im) + return out diff --git a/monai/apps/reconstruction/fft_utils.py b/monai/networks/blocks/fft_utils_t.py similarity index 80% rename from monai/apps/reconstruction/fft_utils.py rename to monai/networks/blocks/fft_utils_t.py index e818772b02..4031882d6e 100644 --- a/monai/apps/reconstruction/fft_utils.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -10,14 +10,13 @@ # limitations under the License. import torch +from torch import Tensor -from monai.config.type_definitions import NdarrayOrTensor -from monai.utils.type_conversion import convert_data_type, convert_to_dst_type - -def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: +def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ - Pytorch-based ifft for spatial_dims-dim signals. + Pytorch-based ifft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift Args: @@ -42,8 +41,6 @@ def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = T output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) """ - # handle numpy format - ksp_t, *_ = convert_data_type(ksp, torch.Tensor) # define spatial dims to perform ifftshift, fftshift, and ifft shift = tuple(range(-spatial_dims, 0)) @@ -53,21 +50,20 @@ def ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = T dims = tuple(range(-spatial_dims, 0)) # apply ifft - x = torch.fft.ifftshift(ksp_t, dim=shift) + x = torch.fft.ifftshift(ksp, dim=shift) if is_complex: x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) - out_t = torch.fft.fftshift(x, dim=shift) + out = torch.fft.fftshift(x, dim=shift) - # handle numpy format - out, *_ = convert_to_dst_type(src=out_t, dst=ksp) return out -def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor: +def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ - Pytorch-based fft for spatial_dims-dim signals. + Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care + of the required ifft and fft shifts. This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift Args: @@ -92,8 +88,6 @@ def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = Tru output2 = fftn_centered(im, spatial_dims=2, is_complex=True) """ - # handle numpy format - im_t, *_ = convert_data_type(im, torch.Tensor) # define spatial dims to perform ifftshift, fftshift, and fft shift = tuple(range(-spatial_dims, 0)) @@ -103,13 +97,11 @@ def fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = Tru dims = tuple(range(-spatial_dims, 0)) # apply fft - x = torch.fft.ifftshift(im_t, dim=shift) + x = torch.fft.ifftshift(im, dim=shift) if is_complex: x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out_t = torch.fft.fftshift(x, dim=shift) + out = torch.fft.fftshift(x, dim=shift) - # handle numpy format - out, *_ = convert_to_dst_type(src=out_t, dst=im) return out diff --git a/tests/test_fft_utils.py b/tests/test_fft_utils.py index 875e8dbdc4..d5e3a22eaa 100644 --- a/tests/test_fft_utils.py +++ b/tests/test_fft_utils.py @@ -11,9 +11,9 @@ import unittest -from monai.apps.reconstruction.fft_utils import fftn_centered, ifftn_centered from parameterized import parameterized +from monai.data.fft_utils import fftn_centered, ifftn_centered from tests.utils import TEST_NDARRAYS, assert_allclose # From 0c3f06703cba84a6eb50fe53b4cbdbc683d18e48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Jun 2022 21:35:13 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/apps.rst | 3 --- docs/source/data.rst | 2 -- 2 files changed, 5 deletions(-) diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 4e5af8fe0c..3bb85c9296 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -184,6 +184,3 @@ Applications :members: .. automodule:: monai.apps.detection.metrics.matching :members: - - - diff --git a/docs/source/data.rst b/docs/source/data.rst index 7d57221830..c3a70da74f 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -333,5 +333,3 @@ Bounding box ------------ .. automodule:: monai.data.box_utils :members: - - From f48dc5abbd3c0ef42c538a27a9d8b89bc904430d Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Sat, 25 Jun 2022 00:35:45 -0700 Subject: [PATCH 11/17] fft_utils_t updated with type ignore for mypy Signed-off-by: mersad95zd --- monai/networks/blocks/fft_utils_t.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 4031882d6e..aaba6af6a5 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -57,7 +57,7 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) out = torch.fft.fftshift(x, dim=shift) - return out + return out # type: ignore[no-any-return] def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: @@ -104,4 +104,4 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) out = torch.fft.fftshift(x, dim=shift) - return out + return out # type: ignore[no-any-return] From 7a61979708bf1e27c5d5cac8380f394216548c0a Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Sat, 25 Jun 2022 00:59:03 -0700 Subject: [PATCH 12/17] docs/source/networks.rst updated with fft_utils_t Signed-off-by: mersad95zd --- docs/source/networks.rst | 7 +++++++ monai/networks/blocks/fft_utils_t.py | 4 +--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 2164fe5d1b..c0f765f0ef 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -233,6 +233,13 @@ Blocks .. autoclass:: DVF2DDF :members: + +N-Dim Fourier Transform +~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: monai.networks.blocks.fft_utils_t +.. autofunction:: monai.networks.blocks.fft_utils_t.fftn_centered_t +.. autofunction:: monai.networks.blocks.fft_utils_t.ifftn_centered_t + Layers ------ diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index aaba6af6a5..3b3154c3c0 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -41,7 +41,6 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True) """ - # define spatial dims to perform ifftshift, fftshift, and ifft shift = tuple(range(-spatial_dims, 0)) if is_complex: @@ -63,7 +62,7 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ Pytorch-based fft for spatial_dims-dim signals. "centered" means this function automatically takes care - of the required ifft and fft shifts. + of the required ifft and fft shifts. This is equivalent to do ifft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift Args: @@ -88,7 +87,6 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T output2 = fftn_centered(im, spatial_dims=2, is_complex=True) """ - # define spatial dims to perform ifftshift, fftshift, and fft shift = tuple(range(-spatial_dims, 0)) if is_complex: From 4fd131fbcaa9f52bd385a26025fe46b85b107763 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Jun 2022 08:00:06 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/networks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index c0f765f0ef..3c7c5cec63 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -233,7 +233,7 @@ Blocks .. autoclass:: DVF2DDF :members: - + N-Dim Fourier Transform ~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: monai.networks.blocks.fft_utils_t From d4ce1fb61bba959dcdfe523768cf858464fb856b Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Sat, 25 Jun 2022 15:51:58 -0700 Subject: [PATCH 14/17] manual fix for fft_utils_t output data types Signed-off-by: mersad95zd --- monai/networks/blocks/fft_utils_t.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 3b3154c3c0..7a96b3d96f 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -12,6 +12,8 @@ import torch from torch import Tensor +from monai.utils.type_conversion import convert_data_type + def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ @@ -54,9 +56,9 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) - out = torch.fft.fftshift(x, dim=shift) + out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] - return out # type: ignore[no-any-return] + return out def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: @@ -100,6 +102,6 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out = torch.fft.fftshift(x, dim=shift) + out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] - return out # type: ignore[no-any-return] + return out From 20645950dd1454e42b0a4bcaf2a1b975c5e09d21 Mon Sep 17 00:00:00 2001 From: mersad95zd Date: Mon, 27 Jun 2022 08:34:05 -0700 Subject: [PATCH 15/17] added support for older pytorch versions Signed-off-by: mersad95zd --- docs/source/data.rst | 1 - docs/source/networks.rst | 4 + monai/networks/blocks/fft_utils_t.py | 130 ++++++++++++++++++++++++++- 3 files changed, 130 insertions(+), 5 deletions(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index c3a70da74f..eab4f867af 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -275,7 +275,6 @@ N-Dim Fourier Transform .. autofunction:: monai.data.fft_utils.ifftn_centered - Meta Object ----------- .. automodule:: monai.data.meta_obj diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 3c7c5cec63..a7b6d8cdc0 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -239,6 +239,10 @@ N-Dim Fourier Transform .. automodule:: monai.networks.blocks.fft_utils_t .. autofunction:: monai.networks.blocks.fft_utils_t.fftn_centered_t .. autofunction:: monai.networks.blocks.fft_utils_t.ifftn_centered_t +.. autofunction:: monai.networks.blocks.fft_utils_t.roll +.. autofunction:: monai.networks.blocks.fft_utils_t.roll_1d +.. autofunction:: monai.networks.blocks.fft_utils_t.fftshift +.. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift Layers ------ diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 7a96b3d96f..6fe8b47448 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -9,12 +9,118 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Sequence + import torch from torch import Tensor from monai.utils.type_conversion import convert_data_type +def roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor: + """ + Similar to roll but for only one dim. + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + shift: the amount of shift along each of shift_dims dimension + shift_dim: the dimension over which the shift is applied + + Returns: + 1d-shifted version of x + + Note: + This function is called when fftshift and ifftshift are not available in the running pytorch version + """ + shift = shift % x.size(shift_dim) + if shift == 0: + return x + + left = x.narrow(shift_dim, 0, x.size(shift_dim) - shift) + right = x.narrow(shift_dim, x.size(shift_dim) - shift, shift) + + return torch.cat((right, left), dim=shift_dim) + + +def roll(x: Tensor, shift: Sequence[int], shift_dims: Sequence[int]) -> Tensor: + """ + Similar to np.roll but applies to PyTorch Tensors + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + shift: the amount of shift along each of shift_dims dimensions + shift_dims: dimensions over which the shift is applied + + Returns: + shifted version of x + + Note: + This function is called when fftshift and ifftshift are not available in the running pytorch version + """ + assert len(shift) == len(shift_dims) + for s, d in zip(shift, shift_dims): + x = roll_1d(x, s, d) + return x + + +def fftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor: + """ + Similar to np.fft.fftshift but applies to PyTorch Tensors + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + shift_dims: dimensions over which the shift is applied + + Returns: + fft-shifted version of x + + Note: + This function is called when fftshift is not available in the running pytorch version + """ + if shift_dims is None: + # for torch.jit.script based on the fastmri repository + shift_dims = [0] * (x.dim()) + for i in range(1, x.dim()): + shift_dims[i] = i + shift = [0] * len(shift_dims) + for i, dim_num in enumerate(shift_dims): + shift[i] = x.shape[dim_num] // 2 + return roll(x, shift, shift_dims) + + +def ifftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor: + """ + Similar to np.fft.ifftshift but applies to PyTorch Tensors + + Args: + x: input data (k-space or image) that can be + 1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or + 2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels. + shift_dims: dimensions over which the shift is applied + + Returns: + ifft-shifted version of x + + Note: + This function is called when ifftshift is not available in the running pytorch version + """ + if shift_dims is None: + # for torch.jit.script based on the fastmri repository + shift_dims = [0] * (x.dim()) + for i in range(1, x.dim()): + shift_dims[i] = i + shift = [0] * len(shift_dims) + for i, dim_num in enumerate(shift_dims): + shift[i] = (x.shape[dim_num] + 1) // 2 + return roll(x, shift, shift_dims) + + def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor: """ Pytorch-based ifft for spatial_dims-dim signals. "centered" means this function automatically takes care @@ -51,12 +157,20 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> dims = tuple(range(-spatial_dims, 0)) # apply ifft - x = torch.fft.ifftshift(ksp, dim=shift) + if hasattr(torch.fft, "ifftshift"): # ifftshift was added in pytorch 1.8 + x = torch.fft.ifftshift(ksp, dim=shift) + else: + x = ifftshift(ksp, shift) + if is_complex: x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho")) - out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] + + if hasattr(torch.fft, "fftshift"): + out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] + else: + out = convert_data_type(fftshift(x, shift), torch.Tensor)[0] return out @@ -97,11 +211,19 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T dims = tuple(range(-spatial_dims, 0)) # apply fft - x = torch.fft.ifftshift(im, dim=shift) + if hasattr(torch.fft, "ifftshift"): # ifftshift was added in pytorch 1.8 + x = torch.fft.ifftshift(im, dim=shift) + else: + x = ifftshift(im, shift) + if is_complex: x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho")) else: x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho")) - out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] + + if hasattr(torch.fft, "fftshift"): + out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0] + else: + out = convert_data_type(fftshift(x, shift), torch.Tensor)[0] return out From 41fb315b4014e9e39b299c84b29e464ff5661918 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 28 Jun 2022 21:08:35 +0100 Subject: [PATCH 16/17] fixes mypy Signed-off-by: Wenqi Li --- monai/data/box_utils.py | 2 +- monai/networks/blocks/feature_pyramid_network.py | 2 +- monai/transforms/utility/array.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py index f0f95dbc3e..47e7b6911d 100644 --- a/monai/data/box_utils.py +++ b/monai/data/box_utils.py @@ -621,7 +621,7 @@ def centers_in_boxes(centers: NdarrayOrTensor, boxes: NdarrayOrTensor, eps: floa min_center_to_border: np.ndarray = np.stack(center_to_border, axis=1).min(axis=1) return min_center_to_border > eps # array[bool] - return torch.stack(center_to_border, dim=1).to(COMPUTE_DTYPE).min(dim=1)[0] > eps # Tensor[bool] + return torch.stack(center_to_border, dim=1).to(COMPUTE_DTYPE).min(dim=1)[0] > eps # type: ignore def boxes_center_distance( diff --git a/monai/networks/blocks/feature_pyramid_network.py b/monai/networks/blocks/feature_pyramid_network.py index 2f7b903a19..2373cfc099 100644 --- a/monai/networks/blocks/feature_pyramid_network.py +++ b/monai/networks/blocks/feature_pyramid_network.py @@ -193,7 +193,7 @@ def __init__( conv_type_: Type[nn.Module] = Conv[Conv.CONV, spatial_dims] for m in self.modules(): if isinstance(m, conv_type_): - nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.kaiming_uniform_(m.weight, a=1) # type: ignore nn.init.constant_(m.bias, 0.0) # type: ignore if extra_blocks is not None: diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0d5bafb026..d50505fe84 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -313,10 +313,10 @@ def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: if isinstance(img, torch.Tensor): outputs = list(torch.split(img, 1, self.dim)) else: - outputs = np.split(img, n_out, self.dim) + outputs = np.split(img, n_out, self.dim) # type: ignore if not self.keepdim: outputs = [o.squeeze(self.dim) for o in outputs] - return outputs + return outputs # type: ignore @deprecated(since="0.8", msg_suffix="please use `SplitDim` instead.") From 2a7b15065df76d38386ce3cadd5918971e3ce86a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 29 Jun 2022 08:45:09 +0100 Subject: [PATCH 17/17] update to remove assert Signed-off-by: Wenqi Li --- monai/networks/blocks/fft_utils_t.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py index 6fe8b47448..0d6b99d7e1 100644 --- a/monai/networks/blocks/fft_utils_t.py +++ b/monai/networks/blocks/fft_utils_t.py @@ -61,7 +61,8 @@ def roll(x: Tensor, shift: Sequence[int], shift_dims: Sequence[int]) -> Tensor: Note: This function is called when fftshift and ifftshift are not available in the running pytorch version """ - assert len(shift) == len(shift_dims) + if len(shift) != len(shift_dims): + raise ValueError(f"len(shift) != len(shift_dims), got f{len(shift)} and f{len(shift_dims)}.") for s, d in zip(shift, shift_dims): x = roll_1d(x, s, d) return x @@ -152,7 +153,8 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> # define spatial dims to perform ifftshift, fftshift, and ifft shift = tuple(range(-spatial_dims, 0)) if is_complex: - assert ksp.shape[-1] == 2 + if ksp.shape[-1] != 2: + raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).") shift = tuple(range(-spatial_dims - 1, -1)) dims = tuple(range(-spatial_dims, 0)) @@ -206,7 +208,8 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T # define spatial dims to perform ifftshift, fftshift, and fft shift = tuple(range(-spatial_dims, 0)) if is_complex: - assert im.shape[-1] == 2 + if im.shape[-1] != 2: + raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).") shift = tuple(range(-spatial_dims - 1, -1)) dims = tuple(range(-spatial_dims, 0))