From 18549d2af9e0af7b0a2e411e72d6eca9adcc7ef7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Jul 2021 13:58:07 +0800 Subject: [PATCH 1/4] [DLMED] enhance for scalar tensor Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 2 +- tests/test_as_discrete.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8913a1a041..fab031b378 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -169,7 +169,7 @@ def __call__( _nclasses = self.n_classes if n_classes is None else n_classes if not isinstance(_nclasses, int): raise AssertionError("One of self.n_classes or n_classes must be an integer") - img = one_hot(img, num_classes=_nclasses, dim=0) + img = one_hot(img.unsqueeze(0) if img.ndim == 0 else img, num_classes=_nclasses, dim=0) if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 658a21efd6..71a47a15cc 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -37,9 +37,16 @@ (1, 2, 2), ] +TEST_CASE_4 = [ + {"argmax": False, "to_onehot": True, "n_classes": 3}, + torch.tensor(1), + torch.tensor([0., 1., 0.]), + (3,), +] + class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) torch.testing.assert_allclose(result, out) From ac4682dbdc63a8dc5962e9d1c35ddac91f4ab236 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 8 Jul 2021 06:04:04 +0000 Subject: [PATCH 2/4] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_as_discrete.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 71a47a15cc..ea806be139 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -40,7 +40,7 @@ TEST_CASE_4 = [ {"argmax": False, "to_onehot": True, "n_classes": 3}, torch.tensor(1), - torch.tensor([0., 1., 0.]), + torch.tensor([0.0, 1.0, 0.0]), (3,), ] From a6201b6ee7d3c7894f8e1ca0ffcde8c3de4c547b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Jul 2021 14:21:48 +0800 Subject: [PATCH 3/4] [DLMED] remove redundant logic Signed-off-by: Nic Ma --- monai/networks/utils.py | 5 +---- monai/transforms/post/array.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d85175ef7e..9d20d2a83b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -48,7 +48,7 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to `num_classes` from `1`. dtype: the data type of the output one_hot label. - dim: the dimension to be converted to `num_classes` channels from `1` channel. + dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number. Example: @@ -69,9 +69,6 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f print(out.shape) # torch.Size([2, 2, 2, 2, 2]) """ - if labels.dim() == 0: - # if no channel dim, add it - labels = labels.unsqueeze(0) # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index fab031b378..8913a1a041 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -169,7 +169,7 @@ def __call__( _nclasses = self.n_classes if n_classes is None else n_classes if not isinstance(_nclasses, int): raise AssertionError("One of self.n_classes or n_classes must be an integer") - img = one_hot(img.unsqueeze(0) if img.ndim == 0 else img, num_classes=_nclasses, dim=0) + img = one_hot(img, num_classes=_nclasses, dim=0) if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) From 53cddd7ea2ca27a3dab310ea814df93f30ca15d7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 8 Jul 2021 14:25:40 +0800 Subject: [PATCH 4/4] [DLMED] add doc-string for img Signed-off-by: Nic Ma --- monai/transforms/post/array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8913a1a041..397b14e2e2 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -150,6 +150,8 @@ def __call__( ) -> torch.Tensor: """ Args: + img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`, + will automatically add it. argmax: whether to execute argmax function on input data before transform. Defaults to ``self.argmax``. to_onehot: whether to convert input data into the one-hot format.