Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
e0d20e5
mri utils added
mersad95zd Jun 21, 2022
aecedaa
fft_utils with its unit test added
mersad95zd Jun 22, 2022
41c93b0
Merge branch 'dev' into 00-mri-utils
Can-Zhao Jun 23, 2022
1e0e39c
fft_utils updated with monai data converter
mersad95zd Jun 23, 2022
34b9da9
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd Jun 23, 2022
a452ddc
updated fft_util's docstring
mersad95zd Jun 23, 2022
962f5f8
apps.rst updated with fft_utils docstrings under the reconstruction m…
mersad95zd Jun 23, 2022
829992c
fft_utils docstring updated by adding dimension hins
mersad95zd Jun 23, 2022
8536cd3
fft_utils docstring updated by removing redundant output type
mersad95zd Jun 23, 2022
ae346ee
test_fft_utils.py moved to the tests folder
mersad95zd Jun 23, 2022
41aa0c7
Merge branch 'dev' into 00-mri-utils
mersad95zd Jun 24, 2022
89f5c99
Merge branch 'dev' into 00-mri-utils
mersad95zd Jun 24, 2022
58086b3
created fft_utils_t, the torch-only version of fft_utils
mersad95zd Jun 24, 2022
9928973
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd Jun 24, 2022
0c3f067
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2022
f48dc5a
fft_utils_t updated with type ignore for mypy
mersad95zd Jun 25, 2022
e4773b4
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd Jun 25, 2022
7a61979
docs/source/networks.rst updated with fft_utils_t
mersad95zd Jun 25, 2022
4fd131f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 25, 2022
d4ce1fb
manual fix for fft_utils_t output data types
mersad95zd Jun 25, 2022
eb2b420
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd Jun 25, 2022
4793885
initial mri_transforms added under apps/reconstruction
mersad95zd Jun 26, 2022
5b4e2ad
PR1 files deleted
mersad95zd Jun 25, 2022
378e63e
Merge branch '4587-mri-transforms' of https://github.com/mersad95zd/M…
mersad95zd Jun 27, 2022
a986290
with PR1 files to avoid error
mersad95zd Jun 27, 2022
3941149
Merge branch '4587-mri-transforms' of https://github.com/mersad95zd/M…
mersad95zd Jun 27, 2022
b78b541
PR1 included
mersad95zd Jun 27, 2022
0be3d34
PR1 files finally removed
mersad95zd Jun 27, 2022
cc42c3b
Merge branch 'dev' into 4587-mri-transforms
mersad95zd Jun 27, 2022
00c9a7c
putting test_mri_transforms under apps for now
mersad95zd Jun 27, 2022
026f141
Merge branch '4587-mri-transforms' of https://github.com/mersad95zd/M…
mersad95zd Jun 27, 2022
f813784
isort fixed
mersad95zd Jun 27, 2022
3b45ee6
initial commit of ssim loss
mersad95zd Jun 27, 2022
aa85e34
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jun 28, 2022
3228ee4
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jun 29, 2022
6109d38
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jun 30, 2022
4c5d23c
format revision and adding ssim_metric
mersad95zd Jun 30, 2022
46f7ebc
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd Jun 30, 2022
813de5b
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jun 30, 2022
9386548
class signature errors resolved for ssim_metric
mersad95zd Jun 30, 2022
c3bb6f4
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd Jun 30, 2022
dec52e6
ssim_metric now calls ssim_loss; moved metric to metrics; moved loss …
mersad95zd Jul 1, 2022
8eba52b
shape added to loss function and metric __call__
mersad95zd Jul 1, 2022
4fe5d39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2022
09d0b3c
minor issues resolved
mersad95zd Jul 1, 2022
c212498
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd Jul 1, 2022
1c599b4
device support added to ssim_loss
mersad95zd Jul 1, 2022
1bdb915
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jul 2, 2022
17db867
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jul 5, 2022
5b941c1
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd Jul 8, 2022
ec03dd8
fixed minor inline issues
mersad95zd Jul 8, 2022
a99e2cf
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd Jul 8, 2022
df76151
docs issue resolved
mersad95zd Jul 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ Registration Losses
.. autoclass:: GlobalMutualInformationLoss
:members:

Reconstruction Losses
---------------------

`SSIMLoss`
~~~~~~~~~~
.. autoclass:: monai.losses.ssim_loss.SSIMLoss
:members:


Loss Wrappers
-------------

Expand Down
4 changes: 4 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ Metrics
.. autoclass:: PSNRMetric
:members:

`Structural similarity index measure`
-------------------------------------
.. autoclass:: monai.metrics.regression.SSIMMetric

`Cumulative average`
--------------------
.. autoclass:: CumulativeAverage
Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .spatial_mask import MaskedLoss
from .ssim_loss import SSIMLoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
93 changes: 93 additions & 0 deletions monai/losses/ssim_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn.functional as F
from torch import nn

from monai.utils.type_conversion import convert_to_dst_type


class SSIMLoss(nn.Module):
Comment thread
mersad95zd marked this conversation as resolved.
"""
Build a Pytorch version of the SSIM loss function based on the original formula of SSIM

Modified and adopted from:
https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py

For more info, visit
https://vicuesoft.com/glossary/term/ssim-ms-ssim/

SSIM reference paper:
Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
"""

def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2):
"""
Args:
win_size: gaussian weighting window size
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
spatial_dims: if 2, input shape is expected to be (B,C,W,H). if 3, it is expected to be (B,C,W,H,D)
"""
super().__init__()
self.win_size = win_size
self.k1, self.k2 = k1, k2
self.spatial_dims = spatial_dims
self.register_buffer(
"w", torch.ones([1, 1] + [win_size for _ in range(spatial_dims)]) / win_size**spatial_dims
)
self.cov_norm = (win_size**2) / (win_size**2 - 1)

def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
"""
Args:
x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
A fastMRI sample should use the 2D format with C being the number of slices.
y: second sample (e.g., the reconstructed image). It has similar shape as x.
data_range: dynamic range of the data

Returns:
1-ssim_value (recall this is meant to be a loss function)

Example:
.. code-block:: python

import torch
x = torch.ones([1,1,10,10])/2
y = torch.ones([1,1,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
"""
data_range = data_range[(None,) * (self.spatial_dims + 2)]
# determine whether to work with 2D convolution or 3D
conv = getattr(F, f"conv{self.spatial_dims}d")
w = convert_to_dst_type(src=self.w, dst=x)[0]

c1 = (self.k1 * data_range) ** 2 # stability constant for luminance
c2 = (self.k2 * data_range) ** 2 # stability constant for contrast
ux = conv(x, w) # mu_x
uy = conv(y, w) # mu_y
uxx = conv(x * x, w) # mu_x^2
uyy = conv(y * y, w) # mu_y^2
uxy = conv(x * y, w) # mu_xy
vx = self.cov_norm * (uxx - ux * ux) # sigma_x
vy = self.cov_norm * (uyy - uy * uy) # sigma_y
vxy = self.cov_norm * (uxy - ux * uy) # sigma_xy

numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
denom = (ux**2 + uy**2 + c1) * (vx + vy + c2)
ssim_value = numerator / denom
loss: torch.Tensor = 1 - ssim_value.mean()
return loss
2 changes: 1 addition & 1 deletion monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
from .meandice import DiceMetric, compute_meandice
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric
from .rocauc import ROCAUCMetric, compute_roc_auc
from .surface_dice import SurfaceDiceMetric, compute_surface_dice
from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance
Expand Down
62 changes: 62 additions & 0 deletions monai/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch

from monai.losses.ssim_loss import SSIMLoss
from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction

Expand Down Expand Up @@ -224,3 +225,64 @@ def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func) -> t
# reduction of batch handled inside __call__() using do_metric_reduction() in respective calling class
flt = partial(torch.flatten, start_dim=1)
return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True)


class SSIMMetric(RegressionMetric):
Comment thread
mersad95zd marked this conversation as resolved.
r"""
Comment thread
wyli marked this conversation as resolved.
Build a Pytorch version of the SSIM metric based on the original formula of SSIM

.. math::
\operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \
\mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}

For more info, visit
https://vicuesoft.com/glossary/term/ssim-ms-ssim/

Modified and adopted from:
https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py

SSIM reference paper:
Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
similarity." IEEE transactions on image processing 13.4 (2004): 600-612.

Args:
data_range: dynamic range of the data
win_size: gaussian weighting window size
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
spatial_dims: if 2, input shape is expected to be (B,C,W,H). if 3, it is expected to be (B,C,W,H,D)
"""

def __init__(
self, data_range: torch.Tensor, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2
):
super().__init__()
self.data_range = data_range
self.win_size = win_size
self.k1, self.k2 = k1, k2
self.spatial_dims = spatial_dims

def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Args:
x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
A fastMRI sample should use the 2D format with C being the number of slices.
y: second sample (e.g., the reconstructed image). It has similar shape as x

Returns:
ssim_value

Example:
.. code-block:: python

import torch
x = torch.ones([1,1,10,10])/2 # ground truth
y = torch.ones([1,1,10,10])/2 # prediction
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(SSIMMetric(data_range=data_range,spatial_dims=2)._compute_metric(x,y))
"""
ssim_value: torch.Tensor = 1 - SSIMLoss(self.win_size, self.k1, self.k2, self.spatial_dims)(
x, y, self.data_range
)
return ssim_value
53 changes: 53 additions & 0 deletions tests/test_ssim_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.losses.ssim_loss import SSIMLoss

x = torch.ones([1, 1, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS2D = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))

x = torch.ones([1, 1, 10, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS3D = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))


class TestSSIMLoss(unittest.TestCase):
@parameterized.expand(TESTS2D)
def test2d(self, x, y, drange, res):
result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)

@parameterized.expand(TESTS3D)
def test3d(self, x, y, drange, res):
result = 1 - SSIMLoss(spatial_dims=3)(x, y, drange)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)


if __name__ == "__main__":
unittest.main()
47 changes: 47 additions & 0 deletions tests/test_ssim_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.metrics.regression import SSIMMetric

x = torch.ones([1, 1, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS2D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))]

x = torch.ones([1, 1, 10, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS3D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))]


class TestSSIMMetric(unittest.TestCase):
@parameterized.expand(TESTS2D)
def test2d(self, x, y, drange, res):
result = SSIMMetric(data_range=drange, spatial_dims=2)._compute_metric(x, y)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)

@parameterized.expand(TESTS3D)
def test3d(self, x, y, drange, res):
result = SSIMMetric(data_range=drange, spatial_dims=3)._compute_metric(x, y)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)


if __name__ == "__main__":
unittest.main()