diff --git a/monai/networks/blocks/aspp.py b/monai/networks/blocks/aspp.py index d995d64796..41ed39c359 100644 --- a/monai/networks/blocks/aspp.py +++ b/monai/networks/blocks/aspp.py @@ -9,14 +9,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence +from typing import Optional, Sequence, Tuple, Union import torch import torch.nn as nn from monai.networks.blocks.convolutions import Convolution from monai.networks.layers import same_padding -from monai.networks.layers.factories import Act, Conv, Norm +from monai.networks.layers.factories import Conv class SimpleASPP(nn.Module): @@ -37,8 +37,8 @@ def __init__( conv_out_channels: int, kernel_sizes: Sequence[int] = (1, 3, 3, 3), dilations: Sequence[int] = (1, 2, 4, 6), - norm_type=Norm.BATCH, - acti_type=Act.LEAKYRELU, + norm_type: Optional[Union[Tuple, str]] = "BATCH", + acti_type: Optional[Union[Tuple, str]] = "LEAKYRELU", ) -> None: """ Args: diff --git a/tests/test_simple_aspp.py b/tests/test_simple_aspp.py index 89ca589c51..fbc8cb37d1 100644 --- a/tests/test_simple_aspp.py +++ b/tests/test_simple_aspp.py @@ -19,12 +19,12 @@ TEST_CASES = [ [ # 32-channel 2D, batch 7 - {"spatial_dims": 2, "in_channels": 32, "conv_out_channels": 3}, + {"spatial_dims": 2, "in_channels": 32, "conv_out_channels": 3, "norm_type": ("batch", {"affine": False})}, (7, 32, 18, 20), (7, 12, 18, 20), ], [ # 4-channel 1D, batch 16 - {"spatial_dims": 1, "in_channels": 4, "conv_out_channels": 8}, + {"spatial_dims": 1, "in_channels": 4, "conv_out_channels": 8, "acti_type": ("PRELU", {"num_parameters": 32})}, (16, 4, 17), (16, 32, 17), ],