From 99c3a6d7a36fe847272b70e5b27bb5f403fccb5c Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 30 Sep 2022 13:59:48 +0800 Subject: [PATCH] fix #5225 Signed-off-by: KumoLiu --- monai/networks/layers/convutils.py | 2 +- monai/networks/nets/varautoencoder.py | 5 ++++- tests/test_varautoencoder.py | 29 ++++++++++++++++++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py index 1e9ce954e8..fe688b24ff 100644 --- a/monai/networks/layers/convutils.py +++ b/monai/networks/layers/convutils.py @@ -74,7 +74,7 @@ def calculate_out_shape( out_shape_np = ((in_shape_np - kernel_size_np + padding_np + padding_np) // stride_np) + 1 out_shape = tuple(int(s) for s in out_shape_np) - return out_shape if len(out_shape) > 1 else out_shape[0] + return out_shape def gaussian_1d( diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py index 7c6928afc0..31c2a5cfe6 100644 --- a/monai/networks/nets/varautoencoder.py +++ b/monai/networks/nets/varautoencoder.py @@ -48,6 +48,7 @@ class VarAutoEncoder(AutoEncoder): bias: whether to have a bias term in convolution blocks. Defaults to True. According to `Performance Tuning Guide `_, if a conv layer is directly followed by a batch norm layer, bias should be False. + use_sigmoid: whether to use the sigmoid function on final output. Defaults to True. Examples:: @@ -86,9 +87,11 @@ def __init__( norm: Union[Tuple, str] = Norm.INSTANCE, dropout: Optional[Union[Tuple, str, float]] = None, bias: bool = True, + use_sigmoid: bool = True, ) -> None: self.in_channels, *self.in_shape = in_shape + self.use_sigmoid = use_sigmoid self.latent_size = latent_size self.final_size = np.asarray(self.in_shape, dtype=int) @@ -148,4 +151,4 @@ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encode_forward(x) z = self.reparameterize(mu, logvar) - return self.decode_forward(z), mu, logvar, z + return self.decode_forward(z, self.use_sigmoid), mu, logvar, z diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py index 04fc07f53f..a6315ebc63 100644 --- a/tests/test_varautoencoder.py +++ b/tests/test_varautoencoder.py @@ -75,7 +75,34 @@ (1, 3, 128, 128, 128), ] -CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] +TEST_CASE_4 = [ # 4-channel 1D, batch 4 + { + "spatial_dims": 1, + "in_shape": (4, 128), + "out_channels": 3, + "latent_size": 2, + "channels": (4, 8, 16), + "strides": (2, 2, 2), + }, + (1, 4, 128), + (1, 3, 128), +] + +TEST_CASE_5 = [ # 4-channel 1D, batch 4, use_sigmoid = False + { + "spatial_dims": 1, + "in_shape": (4, 128), + "out_channels": 3, + "latent_size": 2, + "channels": (4, 8, 16), + "strides": (2, 2, 2), + "use_sigmoid": False, + }, + (1, 4, 128), + (1, 3, 128), +] + +CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5] class TestVarAutoEncoder(unittest.TestCase):