diff --git a/monai/networks/utils.py b/monai/networks/utils.py index bd25e358f6..c5989f174b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -71,9 +71,7 @@ def slice_channels(tensor: torch.Tensor, *slicevals: Optional[int]) -> torch.Ten return tensor[slices] -def predict_segmentation( - logits: torch.Tensor, mutually_exclusive: bool = False, threshold: float = 0.0 -) -> torch.Tensor: +def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False, threshold: float = 0.0) -> Any: """ Given the logits from a network, computing the segmentation by thresholding all values above 0 if multi-labels task, computing the `argmax` along the channel axis if multi-classes task,