From 293a57dfa764be7f47383045fae4cc7aaf39158d Mon Sep 17 00:00:00 2001 From: vnath Date: Sun, 7 Nov 2021 19:45:40 -0600 Subject: [PATCH 1/7] Contrastive Loss added, first draft Signed-off-by: vnath --- monai/losses/contrastive.py | 97 +++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 monai/losses/contrastive.py diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py new file mode 100644 index 0000000000..e580825dd0 --- /dev/null +++ b/monai/losses/contrastive.py @@ -0,0 +1,97 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Callable, List, Optional, Union + +import torch +from torch.nn import functional as F +from torch.nn.modules.loss import _Loss + +from monai.networks import one_hot +from monai.utils import LossReduction + +class ContrasiveLoss(_Loss): + + """ + Compute the Contrastive loss defined in: + + Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International + conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v119/chen20j.html) + + Adapted from: + https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/losses.py#L5 + + """ + + def __init__( + self, + normalize: bool = True, + temperature: float = 0.5, + batch_size: int = 1, + ) -> None: + """ + Args: + normalize: If True, input feature vector is normalized along the vector (B, F). F will be normalized + temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. + + Raises: + TypeError: When ``other_act`` is not an ``Optional[Callable]``. + ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. + Incompatible values. + + """ + self.batch_size = batch_size + self.normalize = normalize + self.temperature = temperature + self.negatives_mask = torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=bool) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be B[F]. + target: the shape should be B[F]. + + Raises: + ValueError: When ``self.reduction`` is not one of ["sum", "none"]. + """ + if target.shape != input.shape: + raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + + if self.normalize: + norm_i = F.normalize(input, dim=1) + norm_j = F.normalize(target, dim=1) + + else: + norm_i = input + norm_j = target + + repr = torch.cat([norm_i, norm_j], dim=0) + sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) + + sim_ij = torch.diag(sim_matrix, self.batch_size) + sim_ji = torch.diag(sim_matrix, -self.batch_size) + + positives = torch.cat([sim_ij, sim_ji], dim=0) + + nominator = torch.exp(positives / self.temperature) + denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature) + + loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) + + if self.reduction == LossReduction.SUM.value: + return torch.sum(loss_partial) / (2 * self.batch_size) + raise ValueError(f'Unsupported reduction: {self.reduction}, ' + f'available options are ["mean", "sum", "none"].') + + + + From 14a58d7462fcbd839741473f91dbeb545be7baa8 Mon Sep 17 00:00:00 2001 From: vnath Date: Mon, 8 Nov 2021 20:27:12 -0600 Subject: [PATCH 2/7] Almost all requirements covered Signed-off-by: vnath --- docs/source/losses.rst | 5 +++ monai/losses/__init__.py | 1 + monai/losses/contrastive.py | 47 +++++++++++--------------- tests/min_tests.py | 1 + tests/test_contrastive_loss.py | 62 ++++++++++++++++++++++++++++++++++ 5 files changed, 88 insertions(+), 28 deletions(-) create mode 100644 tests/test_contrastive_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index fc7c302ea3..dfd8ce2ddb 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -63,6 +63,11 @@ Segmentation Losses .. autoclass:: TverskyLoss :members: +`ContrastiveLoss` +~~~~~~~~~~~~~~~~~ +.. autoclass:: ContrastiveLoss + :members: + Registration Losses ------------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index 3e307fed22..3eca68cc4f 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .contrastive import ContrastiveLoss from .deform import BendingEnergyLoss from .dice import ( Dice, diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index e580825dd0..273814c53b 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -9,17 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from typing import Callable, List, Optional, Union +from typing import Union import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss -from monai.networks import one_hot from monai.utils import LossReduction -class ContrasiveLoss(_Loss): + +class ContrastiveLoss(_Loss): """ Compute the Contrastive loss defined in: @@ -33,26 +32,22 @@ class ContrasiveLoss(_Loss): """ def __init__( - self, - normalize: bool = True, - temperature: float = 0.5, - batch_size: int = 1, + self, temperature: float = 0.5, batch_size: int = 1, reduction: Union[LossReduction, str] = LossReduction.SUM ) -> None: """ Args: - normalize: If True, input feature vector is normalized along the vector (B, F). F will be normalized temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. Raises: - TypeError: When ``other_act`` is not an ``Optional[Callable]``. - ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. - Incompatible values. + AssertionError: When an input of dimension length > 2 is passed + AssertionError: When input and target are of different shapes """ + super().__init__(reduction=LossReduction(reduction).value) + self.batch_size = batch_size - self.normalize = normalize self.temperature = temperature - self.negatives_mask = torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=bool) + self.negatives_mask = torch.eye(self.batch_size * 2, self.batch_size * 2) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -63,16 +58,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["sum", "none"]. """ + if len(target.shape) > 2 or len(input.shape) > 2: + raise AssertionError( + f"Either target or input has dimensions greater than 2 where target " + f"shape is ({target.shape}) and input shape is ({input.shape})" + ) + if target.shape != input.shape: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - if self.normalize: - norm_i = F.normalize(input, dim=1) - norm_j = F.normalize(target, dim=1) - - else: - norm_i = input - norm_j = target + norm_i = F.normalize(input, dim=1) + norm_j = F.normalize(target, dim=1) repr = torch.cat([norm_i, norm_j], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) @@ -83,15 +79,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: positives = torch.cat([sim_ij, sim_ji], dim=0) nominator = torch.exp(positives / self.temperature) - denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature) + denominator = self.negatives_mask * torch.exp(sim_matrix / self.temperature) loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) if self.reduction == LossReduction.SUM.value: return torch.sum(loss_partial) / (2 * self.batch_size) - raise ValueError(f'Unsupported reduction: {self.reduction}, ' - f'available options are ["mean", "sum", "none"].') - - - - + raise ValueError(f"Unsupported reduction: {self.reduction}, " f'available options are ["mean", "sum", "none"].') diff --git a/tests/min_tests.py b/tests/min_tests.py index 2dbc24533d..d63b9215f6 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -33,6 +33,7 @@ def run_testsuit(): "test_cachedataset_parallel", "test_cachedataset_persistent_workers", "test_cachentransdataset", + "test_contrastive_loss", "test_csv_dataset", "test_csv_iterable_dataset", "test_dataset", diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py new file mode 100644 index 0000000000..a01fd20ed4 --- /dev/null +++ b/tests/test_contrastive_loss.py @@ -0,0 +1,62 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import ContrastiveLoss + +TEST_CASES = [ + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 1}, + {"input": torch.tensor([[1.0, 1.0, 0.0, 0.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 0.0, + ], + [ # shape: (2, 4), (2, 4) + {"temperature": 0.5, "batch_size": 2}, + { + "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 0.0, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 1}, + {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 2.0, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.05, "batch_size": 1}, + {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, + 20.0, + ], +] + + +class TestContrastiveLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + contrastiveloss = ContrastiveLoss(**input_param) + result = contrastiveloss(**input_data) + print(result) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + def test_ill_shape(self): + loss = ContrastiveLoss(temperature=0.5, batch_size=1) + with self.assertRaisesRegex(AssertionError, ""): + loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + + +if __name__ == "__main__": + unittest.main() From cb0c072f38cf987da113a0de666c26d2cd7809bb Mon Sep 17 00:00:00 2001 From: vnath Date: Tue, 9 Nov 2021 16:09:39 -0600 Subject: [PATCH 3/7] Review Ready Signed-off-by: vnath --- monai/losses/contrastive.py | 6 ++++-- tests/test_contrastive_loss.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index 273814c53b..44073e1246 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -47,7 +47,6 @@ def __init__( self.batch_size = batch_size self.temperature = temperature - self.negatives_mask = torch.eye(self.batch_size * 2, self.batch_size * 2) def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -70,6 +69,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: norm_i = F.normalize(input, dim=1) norm_j = F.normalize(target, dim=1) + negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) + negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) + repr = torch.cat([norm_i, norm_j], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) @@ -79,7 +81,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: positives = torch.cat([sim_ij, sim_ji], dim=0) nominator = torch.exp(positives / self.temperature) - denominator = self.negatives_mask * torch.exp(sim_matrix / self.temperature) + denominator = negatives_mask * torch.exp(sim_matrix / self.temperature) loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py index a01fd20ed4..f7e3faf80f 100644 --- a/tests/test_contrastive_loss.py +++ b/tests/test_contrastive_loss.py @@ -29,17 +29,25 @@ "input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), }, - 0.0, + 1.0986, + ], + [ # shape: (1, 4), (1, 4) + {"temperature": 0.5, "batch_size": 2}, + { + "input": torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 0.0, 0.0]]), + "target": torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]), + }, + 0.8719, ], [ # shape: (1, 4), (1, 4) {"temperature": 0.5, "batch_size": 1}, {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, - 2.0, + 0.0, ], [ # shape: (1, 4), (1, 4) {"temperature": 0.05, "batch_size": 1}, {"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])}, - 20.0, + 0.0, ], ] @@ -49,7 +57,6 @@ class TestContrastiveLoss(unittest.TestCase): def test_result(self, input_param, input_data, expected_val): contrastiveloss = ContrastiveLoss(**input_param) result = contrastiveloss(**input_data) - print(result) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) def test_ill_shape(self): From 54905150272d19674d76c6e505b9fd57b2e98c67 Mon Sep 17 00:00:00 2001 From: vnath Date: Wed, 10 Nov 2021 21:12:18 -0600 Subject: [PATCH 4/7] CL Loss needs negative mask to be on the device, fix commited Signed-off-by: vnath --- monai/losses/contrastive.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index 44073e1246..590e418879 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -32,7 +32,11 @@ class ContrastiveLoss(_Loss): """ def __init__( - self, temperature: float = 0.5, batch_size: int = 1, reduction: Union[LossReduction, str] = LossReduction.SUM + self, + temperature: float = 0.5, + batch_size: int = 1, + device: str = "cpu", + reduction: Union[LossReduction, str] = LossReduction.SUM, ) -> None: """ Args: @@ -47,6 +51,7 @@ def __init__( self.batch_size = batch_size self.temperature = temperature + self.device = device def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -70,6 +75,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: norm_j = F.normalize(target, dim=1) negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) + negatives_mask = negatives_mask.to(self.device) negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) repr = torch.cat([norm_i, norm_j], dim=0) From 5c41d6446bc127c522b8460f40f8aec34a9848f2 Mon Sep 17 00:00:00 2001 From: vnath Date: Thu, 11 Nov 2021 14:26:47 -0600 Subject: [PATCH 5/7] Cuda test added to CL loss test file Signed-off-by: vnath --- monai/losses/contrastive.py | 9 ++------- tests/test_contrastive_loss.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index 590e418879..b091fd76e9 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -32,11 +32,7 @@ class ContrastiveLoss(_Loss): """ def __init__( - self, - temperature: float = 0.5, - batch_size: int = 1, - device: str = "cpu", - reduction: Union[LossReduction, str] = LossReduction.SUM, + self, temperature: float = 0.5, batch_size: int = 1, reduction: Union[LossReduction, str] = LossReduction.SUM ) -> None: """ Args: @@ -51,7 +47,6 @@ def __init__( self.batch_size = batch_size self.temperature = temperature - self.device = device def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -75,7 +70,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: norm_j = F.normalize(target, dim=1) negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) - negatives_mask = negatives_mask.to(self.device) + negatives_mask = torch.clone(torch.as_tensor(negatives_mask)).to(input.device) negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) repr = torch.cat([norm_i, norm_j], dim=0) diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py index f7e3faf80f..b9caecce65 100644 --- a/tests/test_contrastive_loss.py +++ b/tests/test_contrastive_loss.py @@ -64,6 +64,16 @@ def test_ill_shape(self): with self.assertRaisesRegex(AssertionError, ""): loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3))) + def test_with_cuda(self): + loss = ContrastiveLoss(temperature=0.5, batch_size=1) + i = torch.ones((1, 10)) + j = torch.ones((1, 10)) + if torch.cuda.is_available(): + i = i.cuda() + j = j.cuda() + output = loss(i, j) + np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + if __name__ == "__main__": unittest.main() From f13f6be582982cf2d3837c6a82277087546626c3 Mon Sep 17 00:00:00 2001 From: vnath Date: Thu, 11 Nov 2021 15:38:20 -0600 Subject: [PATCH 6/7] Fix for tests failing on PT16 due to device mismatch Signed-off-by: vnath --- monai/losses/contrastive.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index b091fd76e9..d2916f024d 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -66,23 +66,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") + temperature_tensor = torch.tensor(self.temperature).to(input.device) + norm_i = F.normalize(input, dim=1) norm_j = F.normalize(target, dim=1) negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) - negatives_mask = torch.clone(torch.as_tensor(negatives_mask)).to(input.device) negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) + negatives_mask = torch.clone(torch.as_tensor(negatives_mask)).to(input.device) repr = torch.cat([norm_i, norm_j], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) - sim_ij = torch.diag(sim_matrix, self.batch_size) sim_ji = torch.diag(sim_matrix, -self.batch_size) positives = torch.cat([sim_ij, sim_ji], dim=0) - - nominator = torch.exp(positives / self.temperature) - denominator = negatives_mask * torch.exp(sim_matrix / self.temperature) + nominator = torch.exp(positives / temperature_tensor) + denominator = negatives_mask * torch.exp(sim_matrix / temperature_tensor) loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) From eae17a7df8d1995a2fe1efa29993835ce42f865b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Nov 2021 21:39:01 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/contrastive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index d2916f024d..22caf3fe7d 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -67,7 +67,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") temperature_tensor = torch.tensor(self.temperature).to(input.device) - + norm_i = F.normalize(input, dim=1) norm_j = F.normalize(target, dim=1)