diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 63d4cfa8..7a8662ef 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -20,73 +20,108 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -TEST_CASE_0 = [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_channels": 32, - "latent_channels": 8, - "ch_mult": [1, 1, 1], - "num_res_blocks": 1, - }, - (2, 1, 64, 64), - (2, 1, 64, 64), - (2, 8, 16, 16), +CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": 4, + "latent_channels": 4, + "ch_mult": [1, 1, 1], + "attention_levels": None, + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": 4, + "latent_channels": 4, + "ch_mult": [1, 1, 1], + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": 4, + "latent_channels": 4, + "ch_mult": [1, 1, 1], + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": 4, + "latent_channels": 4, + "ch_mult": [1, 1, 1], + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": 4, + "latent_channels": 4, + "ch_mult": [1, 1, 1], + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + "with_encoder_nonlocal_attn": False, + "with_decoder_nonlocal_attn": False, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_channels": 4, + "latent_channels": 4, + "ch_mult": [1, 1, 1], + "attention_levels": (False, False, True), + "num_res_blocks": 1, + "norm_num_groups": 4, + }, + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + (1, 4, 4, 4, 4), + ], ] -TEST_CASE_1 = [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_channels": 32, - "latent_channels": 32, - "ch_mult": [1, 1, 1, 1], - "num_res_blocks": 1, - }, - (2, 1, 64, 64), - (2, 1, 64, 64), - (2, 32, 8, 8), -] - -TEST_CASE_2 = [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_channels": 32, - "latent_channels": 32, - "ch_mult": [1, 1, 1, 1], - "num_res_blocks": 1, - "attention_levels": (False, False, False, True), - "with_encoder_nonlocal_attn": False, - }, - (2, 1, 64, 64), - (2, 1, 64, 64), - (2, 32, 8, 8), -] - -TEST_CASE_3 = [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "num_channels": 32, - "latent_channels": 32, - "ch_mult": [1, 1, 1, 1], - "num_res_blocks": 1, - "attention_levels": (True, True, True, True), - "with_encoder_nonlocal_attn": False, - "with_decoder_nonlocal_attn": False, - }, - (2, 1, 64, 64), - (2, 1, 64, 64), - (2, 32, 8, 8), -] - -CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3] - class TestAutoEncoderKL(unittest.TestCase): @parameterized.expand(CASES) @@ -131,13 +166,37 @@ def test_model_ch_mult_not_same_size_of_attention_levels(self): attention_levels=(True,), ) - @parameterized.expand(CASES) - def test_shape_reconstruction(self, input_param, input_shape, expected_shape, _): + def test_shape_reconstruction(self): + input_param, input_shape, expected_shape, _ = CASES[0] net = AutoencoderKL(**input_param).to(device) with eval_mode(net): result = net.reconstruct(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) + def test_shape_encode(self): + input_param, input_shape, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.encode(torch.randn(input_shape).to(device)) + self.assertEqual(result[0].shape, expected_latent_shape) + self.assertEqual(result[1].shape, expected_latent_shape) + + def test_shape_sampling(self): + input_param, _, _, expected_latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.sampling( + torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device) + ) + self.assertEqual(result.shape, expected_latent_shape) + + def test_shape_decode(self): + input_param, expected_input_shape, _, latent_shape = CASES[0] + net = AutoencoderKL(**input_param).to(device) + with eval_mode(net): + result = net.decode(torch.randn(latent_shape).to(device)) + self.assertEqual(result.shape, expected_input_shape) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py index e516ae1f..db7f4ef5 100644 --- a/tests/test_vqvae.py +++ b/tests/test_vqvae.py @@ -10,7 +10,6 @@ # limitations under the License. import unittest -from itertools import product import torch from monai.networks import eval_mode @@ -19,28 +18,20 @@ from generative.networks.nets.vqvae import VQVAE -configurations = product( - [2, 4], # Number of downsamplings - [16, 64], # Embedding dimension - [1, 3], # Batch size - [1, 3], # Number of input channels - [64, 256], # Spatial input shape -) - -CASES_2D = [ +TEST_CASES = [ [ { "spatial_dims": 2, - "in_channels": in_channels, - "out_channels": in_channels, - "num_levels": no_levels, - "downsample_parameters": [(2, 4, 1, 1)] * no_levels, - "upsample_parameters": [(2, 4, 1, 1, 0)] * no_levels, + "in_channels": 1, + "out_channels": 1, + "num_levels": 2, + "downsample_parameters": [(2, 4, 1, 1)] * 2, + "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_res_layers": 1, - "num_channels": [4] * (no_levels - 1) + [8], - "num_res_channels": [4] * (no_levels - 1) + [8], - "num_embeddings": 256, - "embedding_dim": embedding_dim, + "num_channels": [8, 8], + "num_res_channels": [8, 8], + "num_embeddings": 16, + "embedding_dim": 8, "embedding_init": "normal", "commitment_cost": 0.25, "decay": 0.5, @@ -50,34 +41,22 @@ "act": "RELU", "output_act": None, }, - (batch_size, in_channels, spatial_input_shape, spatial_input_shape), - (batch_size, in_channels, spatial_input_shape, spatial_input_shape), - ] - for no_levels, embedding_dim, batch_size, in_channels, spatial_input_shape in configurations -] - -configurations = product( - [2, 4], # Number of downsamplings - [16, 64], # Embedding dimension - [1, 3], # Batch size - [1, 3], # Number of input channels - [64, 256], # Spatial input shape -) - -CASES_3D = [ + (1, 1, 16, 16), + (1, 1, 16, 16), + ], [ { "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": in_channels, - "num_levels": no_levels, - "downsample_parameters": [(2, 4, 1, 1)] * no_levels, - "upsample_parameters": [(2, 4, 1, 1, 0)] * no_levels, + "in_channels": 1, + "out_channels": 1, + "num_levels": 2, + "downsample_parameters": [(2, 4, 1, 1)] * 2, + "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_res_layers": 1, - "num_channels": [4] * (no_levels - 1) + [8], - "num_res_channels": [4] * (no_levels - 1) + [8], - "num_embeddings": 256, - "embedding_dim": embedding_dim, + "num_channels": [8, 8], + "num_res_channels": [8, 8], + "num_embeddings": 16, + "embedding_dim": 8, "embedding_init": "normal", "commitment_cost": 0.25, "decay": 0.5, @@ -87,14 +66,12 @@ "act": "RELU", "output_act": None, }, - (batch_size, in_channels, spatial_input_shape, spatial_input_shape, spatial_input_shape), - (batch_size, in_channels, spatial_input_shape, spatial_input_shape, spatial_input_shape), - ] - for no_levels, embedding_dim, batch_size, in_channels, spatial_input_shape in configurations + (1, 1, 16, 16, 16), + (1, 1, 16, 16, 16), + ], ] -# 1-channel 2D, should fail because of number of levels, number of downsamplings, number of upsamplings, num_channels -# and num_res_channels mismatch. +# 1-channel 2D, should fail because of number of levels, number of downsamplings, number of upsamplings mismatch. TEST_CASE_FAIL = { "spatial_dims": 3, "in_channels": 1, @@ -103,10 +80,10 @@ "downsample_parameters": [(2, 4, 1, 1)] * 2, "upsample_parameters": [(2, 4, 1, 1, 0)] * 4, "num_res_layers": 1, - "num_channels": [4] * 1 + [8], - "num_res_channels": [4] * 5 + [8], - "num_embeddings": 256, - "embedding_dim": 32, + "num_channels": [8, 8], + "num_res_channels": [8, 8], + "num_embeddings": 16, + "embedding_dim": 8, "embedding_init": "normal", "commitment_cost": 0.25, "decay": 0.5, @@ -121,14 +98,14 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_levels": 4, - "downsample_parameters": [(2, 4, 1, 1)] * 4, - "upsample_parameters": [(2, 4, 1, 1, 0)] * 4, + "num_levels": 2, + "downsample_parameters": [(2, 4, 1, 1)] * 2, + "upsample_parameters": [(2, 4, 1, 1, 0)] * 2, "num_res_layers": 1, - "num_channels": [4] * (4 - 1) + [8], - "num_res_channels": [4] * (4 - 1) + [8], - "num_embeddings": 256, - "embedding_dim": 32, + "num_channels": [8, 8], + "num_res_channels": [8, 8], + "num_embeddings": 16, + "embedding_dim": 8, "embedding_init": "normal", "commitment_cost": 0.25, "decay": 0.5, @@ -141,7 +118,7 @@ class TestVQVAE(unittest.TestCase): - @parameterized.expand(CASES_2D + CASES_3D) + @parameterized.expand(TEST_CASES) def test_shape(self, input_param, input_shape, expected_shape): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -157,14 +134,14 @@ def test_script(self): spatial_dims=2, in_channels=1, out_channels=1, - num_levels=4, - downsample_parameters=tuple([(2, 4, 1, 1)] * 4), - upsample_parameters=tuple([(2, 4, 1, 1, 0)] * 4), + num_levels=2, + downsample_parameters=tuple([(2, 4, 1, 1)] * 2), + upsample_parameters=tuple([(2, 4, 1, 1, 0)] * 2), num_res_layers=1, - num_channels=[4] * (4 - 1) + [8], - num_res_channels=[4] * (4 - 1) + [8], - num_embeddings=2048, - embedding_dim=32, + num_channels=[8, 8], + num_res_channels=[8, 8], + num_embeddings=16, + embedding_dim=8, embedding_init="normal", commitment_cost=0.25, decay=0.5, @@ -175,7 +152,7 @@ def test_script(self): output_act=None, ddp_sync=False, ) - test_data = torch.randn(2, 1, 256, 256) + test_data = torch.randn(1, 1, 16, 16) test_script_save(net, test_data) def test_level_upsample_downsample_difference(self): @@ -188,9 +165,9 @@ def test_encode_shape(self): net = VQVAE(**TEST_LATENT_SHAPE).to(device) with eval_mode(net): - latent = net.encode(torch.randn(2, 1, 256, 256).to(device)) + latent = net.encode(torch.randn(1, 1, 32, 32).to(device)) - self.assertEqual(latent.shape, (2, 32, 16, 16)) + self.assertEqual(latent.shape, (1, 8, 8, 8)) def test_index_quantize_shape(self): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -198,9 +175,9 @@ def test_index_quantize_shape(self): net = VQVAE(**TEST_LATENT_SHAPE).to(device) with eval_mode(net): - latent = net.index_quantize(torch.randn(2, 1, 256, 256).to(device)) + latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device)) - self.assertEqual(latent.shape, (2, 16, 16)) + self.assertEqual(latent.shape, (1, 8, 8)) def test_decode_shape(self): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -208,9 +185,9 @@ def test_decode_shape(self): net = VQVAE(**TEST_LATENT_SHAPE).to(device) with eval_mode(net): - latent = net.decode(torch.randn(2, 32, 16, 16).to(device)) + latent = net.decode(torch.randn(1, 8, 8, 8).to(device)) - self.assertEqual(latent.shape, (2, 1, 256, 256)) + self.assertEqual(latent.shape, (1, 1, 32, 32)) def test_decode_samples_shape(self): device = "cuda" if torch.cuda.is_available() else "cpu" @@ -218,9 +195,9 @@ def test_decode_samples_shape(self): net = VQVAE(**TEST_LATENT_SHAPE).to(device) with eval_mode(net): - latent = net.decode_samples(torch.randint(low=0, high=256, size=(2, 16, 16)).to(device)) + latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device)) - self.assertEqual(latent.shape, (2, 1, 256, 256)) + self.assertEqual(latent.shape, (1, 1, 32, 32)) if __name__ == "__main__":