diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index 9c4bfcf6ee..c52ab07a04 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -34,6 +34,10 @@ def generate_param_groups( layer_matches: a list of callable functions to select or filter out network layer groups, for "select" type, the input will be the `network`, for "filter" type, the input will be every item of `network.named_parameters()`. + for "select", the parameters will be + `select_func(network).parameters()`. + for "filter", the parameters will be + `map(lambda x: x[1], filter(filter_func, network.named_parameters()))` match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions, can be "select" or "filter". lr_values: a list of LR values corresponding to the `layer_matches` functions. @@ -48,7 +52,7 @@ def generate_param_groups( print(net.named_parameters()) # print out all the named parameters to filter out expected items params = generate_param_groups( network=net, - layer_matches=[lambda x: x.model[-1], lambda x: "conv.weight" in x], + layer_matches=[lambda x: x.model[0], lambda x: "2.0.conv" in x[0]], match_types=["select", "filter"], lr_values=[1e-2, 1e-3], ) @@ -71,7 +75,8 @@ def _select(): def _get_filter(f): def _filter(): - return filter(f, network.named_parameters()) + # should eventually generate a list of network parameters + return map(lambda x: x[1], filter(f, network.named_parameters())) return _filter diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index 8ccb8b7977..ea1fad44f9 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -25,6 +25,7 @@ "lr_values": [1], }, (1, 100), + [5, 21], ] TEST_CASE_2 = [ @@ -34,6 +35,7 @@ "lr_values": [1, 2, 3], }, (1, 2, 3, 100), + [5, 16, 5, 0], ] TEST_CASE_3 = [ @@ -43,15 +45,17 @@ "lr_values": [1], }, (1, 100), + [2, 24], ] TEST_CASE_4 = [ { - "layer_matches": [lambda x: x.model[-1], lambda x: "conv.weight" in x], + "layer_matches": [lambda x: x.model[0], lambda x: "2.0.conv" in x[0]], "match_types": ["select", "filter"], "lr_values": [1, 2], }, (1, 2, 100), + [5, 4, 17], ] TEST_CASE_5 = [ @@ -62,12 +66,24 @@ "include_others": False, }, (1), + [5], +] + +TEST_CASE_6 = [ + { + "layer_matches": [lambda x: "weight" in x[0]], + "match_types": ["filter"], + "lr_values": [1], + "include_others": True, + }, + (1), + [16, 10], ] class TestGenerateParamGroups(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_lr_values(self, input_param, expected_values): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_lr_values(self, input_param, expected_values, expected_groups): device = "cuda" if torch.cuda.is_available() else "cpu" net = Unet( dimensions=3, @@ -85,7 +101,7 @@ def test_lr_values(self, input_param, expected_values): torch.testing.assert_allclose(param_group["lr"], value) n = [len(p["params"]) for p in params] - assert sum(n) == 26 or all(n), "should have either full model or non-empty subsets." + self.assertListEqual(n, expected_groups) def test_wrong(self): """overlapped"""