diff --git a/monai/networks/utils.py b/monai/networks/utils.py index d85175ef7e..9d20d2a83b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -48,7 +48,7 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to `num_classes` from `1`. dtype: the data type of the output one_hot label. - dim: the dimension to be converted to `num_classes` channels from `1` channel. + dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number. Example: @@ -69,9 +69,6 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f print(out.shape) # torch.Size([2, 2, 2, 2, 2]) """ - if labels.dim() == 0: - # if no channel dim, add it - labels = labels.unsqueeze(0) # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 8913a1a041..397b14e2e2 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -150,6 +150,8 @@ def __call__( ) -> torch.Tensor: """ Args: + img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`, + will automatically add it. argmax: whether to execute argmax function on input data before transform. Defaults to ``self.argmax``. to_onehot: whether to convert input data into the one-hot format. diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 658a21efd6..ea806be139 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -37,9 +37,16 @@ (1, 2, 2), ] +TEST_CASE_4 = [ + {"argmax": False, "to_onehot": True, "n_classes": 3}, + torch.tensor(1), + torch.tensor([0.0, 1.0, 0.0]), + (3,), +] + class TestAsDiscrete(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_value_shape(self, input_param, img, out, expected_shape): result = AsDiscrete(**input_param)(img) torch.testing.assert_allclose(result, out)