-
Notifications
You must be signed in to change notification settings - Fork 1.6k
4599 mri ssim loss #4600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
4599 mri ssim loss #4600
Changes from all commits
Commits
Show all changes
53 commits
Select commit
Hold shift + click to select a range
e0d20e5
mri utils added
mersad95zd aecedaa
fft_utils with its unit test added
mersad95zd 41c93b0
Merge branch 'dev' into 00-mri-utils
Can-Zhao 1e0e39c
fft_utils updated with monai data converter
mersad95zd 34b9da9
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd a452ddc
updated fft_util's docstring
mersad95zd 962f5f8
apps.rst updated with fft_utils docstrings under the reconstruction m…
mersad95zd 829992c
fft_utils docstring updated by adding dimension hins
mersad95zd 8536cd3
fft_utils docstring updated by removing redundant output type
mersad95zd ae346ee
test_fft_utils.py moved to the tests folder
mersad95zd 41aa0c7
Merge branch 'dev' into 00-mri-utils
mersad95zd 89f5c99
Merge branch 'dev' into 00-mri-utils
mersad95zd 58086b3
created fft_utils_t, the torch-only version of fft_utils
mersad95zd 9928973
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd 0c3f067
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f48dc5a
fft_utils_t updated with type ignore for mypy
mersad95zd e4773b4
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd 7a61979
docs/source/networks.rst updated with fft_utils_t
mersad95zd 4fd131f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d4ce1fb
manual fix for fft_utils_t output data types
mersad95zd eb2b420
Merge branch '00-mri-utils' of https://github.com/mersad95zd/MONAI in…
mersad95zd 4793885
initial mri_transforms added under apps/reconstruction
mersad95zd 5b4e2ad
PR1 files deleted
mersad95zd 378e63e
Merge branch '4587-mri-transforms' of https://github.com/mersad95zd/M…
mersad95zd a986290
with PR1 files to avoid error
mersad95zd 3941149
Merge branch '4587-mri-transforms' of https://github.com/mersad95zd/M…
mersad95zd b78b541
PR1 included
mersad95zd 0be3d34
PR1 files finally removed
mersad95zd cc42c3b
Merge branch 'dev' into 4587-mri-transforms
mersad95zd 00c9a7c
putting test_mri_transforms under apps for now
mersad95zd 026f141
Merge branch '4587-mri-transforms' of https://github.com/mersad95zd/M…
mersad95zd f813784
isort fixed
mersad95zd 3b45ee6
initial commit of ssim loss
mersad95zd aa85e34
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd 3228ee4
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd 6109d38
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd 4c5d23c
format revision and adding ssim_metric
mersad95zd 46f7ebc
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd 813de5b
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd 9386548
class signature errors resolved for ssim_metric
mersad95zd c3bb6f4
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd dec52e6
ssim_metric now calls ssim_loss; moved metric to metrics; moved loss …
mersad95zd 8eba52b
shape added to loss function and metric __call__
mersad95zd 4fe5d39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 09d0b3c
minor issues resolved
mersad95zd c212498
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd 1c599b4
device support added to ssim_loss
mersad95zd 1bdb915
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd 17db867
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd 5b941c1
Merge branch 'dev' into 4599-mri-ssim-loss
mersad95zd ec03dd8
fixed minor inline issues
mersad95zd a99e2cf
Merge branch '4599-mri-ssim-loss' of https://github.com/mersad95zd/MO…
mersad95zd df76151
docs issue resolved
mersad95zd File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch import nn | ||
|
|
||
| from monai.utils.type_conversion import convert_to_dst_type | ||
|
|
||
|
|
||
| class SSIMLoss(nn.Module): | ||
| """ | ||
| Build a Pytorch version of the SSIM loss function based on the original formula of SSIM | ||
|
|
||
| Modified and adopted from: | ||
| https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py | ||
|
|
||
| For more info, visit | ||
| https://vicuesoft.com/glossary/term/ssim-ms-ssim/ | ||
|
|
||
| SSIM reference paper: | ||
| Wang, Zhou, et al. "Image quality assessment: from error visibility to structural | ||
| similarity." IEEE transactions on image processing 13.4 (2004): 600-612. | ||
| """ | ||
|
|
||
| def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2): | ||
| """ | ||
| Args: | ||
| win_size: gaussian weighting window size | ||
| k1: stability constant used in the luminance denominator | ||
| k2: stability constant used in the contrast denominator | ||
| spatial_dims: if 2, input shape is expected to be (B,C,W,H). if 3, it is expected to be (B,C,W,H,D) | ||
| """ | ||
| super().__init__() | ||
| self.win_size = win_size | ||
| self.k1, self.k2 = k1, k2 | ||
| self.spatial_dims = spatial_dims | ||
| self.register_buffer( | ||
| "w", torch.ones([1, 1] + [win_size for _ in range(spatial_dims)]) / win_size**spatial_dims | ||
| ) | ||
| self.cov_norm = (win_size**2) / (win_size**2 - 1) | ||
|
|
||
| def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. | ||
| A fastMRI sample should use the 2D format with C being the number of slices. | ||
| y: second sample (e.g., the reconstructed image). It has similar shape as x. | ||
| data_range: dynamic range of the data | ||
|
|
||
| Returns: | ||
| 1-ssim_value (recall this is meant to be a loss function) | ||
|
|
||
| Example: | ||
| .. code-block:: python | ||
|
|
||
| import torch | ||
| x = torch.ones([1,1,10,10])/2 | ||
| y = torch.ones([1,1,10,10])/2 | ||
| data_range = x.max().unsqueeze(0) | ||
| # the following line should print 1.0 (or 0.9999) | ||
| print(1-SSIMLoss(spatial_dims=2)(x,y,data_range)) | ||
| """ | ||
| data_range = data_range[(None,) * (self.spatial_dims + 2)] | ||
| # determine whether to work with 2D convolution or 3D | ||
| conv = getattr(F, f"conv{self.spatial_dims}d") | ||
| w = convert_to_dst_type(src=self.w, dst=x)[0] | ||
|
|
||
| c1 = (self.k1 * data_range) ** 2 # stability constant for luminance | ||
| c2 = (self.k2 * data_range) ** 2 # stability constant for contrast | ||
| ux = conv(x, w) # mu_x | ||
| uy = conv(y, w) # mu_y | ||
| uxx = conv(x * x, w) # mu_x^2 | ||
| uyy = conv(y * y, w) # mu_y^2 | ||
| uxy = conv(x * y, w) # mu_xy | ||
| vx = self.cov_norm * (uxx - ux * ux) # sigma_x | ||
| vy = self.cov_norm * (uyy - uy * uy) # sigma_y | ||
| vxy = self.cov_norm * (uxy - ux * uy) # sigma_xy | ||
|
|
||
| numerator = (2 * ux * uy + c1) * (2 * vxy + c2) | ||
| denom = (ux**2 + uy**2 + c1) * (vx + vy + c2) | ||
| ssim_value = numerator / denom | ||
| loss: torch.Tensor = 1 - ssim_value.mean() | ||
| return loss | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from parameterized import parameterized | ||
|
|
||
| from monai.losses.ssim_loss import SSIMLoss | ||
|
|
||
| x = torch.ones([1, 1, 10, 10]) / 2 | ||
| y1 = torch.ones([1, 1, 10, 10]) / 2 | ||
| y2 = torch.zeros([1, 1, 10, 10]) | ||
| data_range = x.max().unsqueeze(0) | ||
| TESTS2D = [] | ||
| for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: | ||
| TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) | ||
| TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) | ||
|
|
||
| x = torch.ones([1, 1, 10, 10, 10]) / 2 | ||
| y1 = torch.ones([1, 1, 10, 10, 10]) / 2 | ||
| y2 = torch.zeros([1, 1, 10, 10, 10]) | ||
| data_range = x.max().unsqueeze(0) | ||
| TESTS3D = [] | ||
| for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]: | ||
| TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device))) | ||
| TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device))) | ||
|
|
||
|
|
||
| class TestSSIMLoss(unittest.TestCase): | ||
| @parameterized.expand(TESTS2D) | ||
| def test2d(self, x, y, drange, res): | ||
| result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange) | ||
| self.assertTrue(isinstance(result, torch.Tensor)) | ||
| self.assertTrue(torch.abs(res - result).item() < 0.001) | ||
|
|
||
| @parameterized.expand(TESTS3D) | ||
| def test3d(self, x, y, drange, res): | ||
| result = 1 - SSIMLoss(spatial_dims=3)(x, y, drange) | ||
| self.assertTrue(isinstance(result, torch.Tensor)) | ||
| self.assertTrue(torch.abs(res - result).item() < 0.001) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from parameterized import parameterized | ||
|
|
||
| from monai.metrics.regression import SSIMMetric | ||
|
|
||
| x = torch.ones([1, 1, 10, 10]) / 2 | ||
| y1 = torch.ones([1, 1, 10, 10]) / 2 | ||
| y2 = torch.zeros([1, 1, 10, 10]) | ||
| data_range = x.max().unsqueeze(0) | ||
| TESTS2D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))] | ||
|
|
||
| x = torch.ones([1, 1, 10, 10, 10]) / 2 | ||
| y1 = torch.ones([1, 1, 10, 10, 10]) / 2 | ||
| y2 = torch.zeros([1, 1, 10, 10, 10]) | ||
| data_range = x.max().unsqueeze(0) | ||
| TESTS3D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))] | ||
|
|
||
|
|
||
| class TestSSIMMetric(unittest.TestCase): | ||
| @parameterized.expand(TESTS2D) | ||
| def test2d(self, x, y, drange, res): | ||
| result = SSIMMetric(data_range=drange, spatial_dims=2)._compute_metric(x, y) | ||
| self.assertTrue(isinstance(result, torch.Tensor)) | ||
| self.assertTrue(torch.abs(res - result).item() < 0.001) | ||
|
|
||
| @parameterized.expand(TESTS3D) | ||
| def test3d(self, x, y, drange, res): | ||
| result = SSIMMetric(data_range=drange, spatial_dims=3)._compute_metric(x, y) | ||
| self.assertTrue(isinstance(result, torch.Tensor)) | ||
| self.assertTrue(torch.abs(res - result).item() < 0.001) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.