Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f
num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to
`num_classes` from `1`.
dtype: the data type of the output one_hot label.
dim: the dimension to be converted to `num_classes` channels from `1` channel.
dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number.

Example:

Expand All @@ -69,9 +69,6 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.f
print(out.shape) # torch.Size([2, 2, 2, 2, 2])

"""
if labels.dim() == 0:
# if no channel dim, add it
labels = labels.unsqueeze(0)

# if `dim` is bigger, add singleton dim at the end
if labels.ndim < dim + 1:
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def __call__(
) -> torch.Tensor:
"""
Args:
img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
will automatically add it.
argmax: whether to execute argmax function on input data before transform.
Defaults to ``self.argmax``.
to_onehot: whether to convert input data into the one-hot format.
Expand Down
9 changes: 8 additions & 1 deletion tests/test_as_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@
(1, 2, 2),
]

TEST_CASE_4 = [
{"argmax": False, "to_onehot": True, "n_classes": 3},
torch.tensor(1),
torch.tensor([0.0, 1.0, 0.0]),
(3,),
]


class TestAsDiscrete(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value_shape(self, input_param, img, out, expected_shape):
result = AsDiscrete(**input_param)(img)
torch.testing.assert_allclose(result, out)
Expand Down