diff --git a/CHANGELOG.md b/CHANGELOG.md index b0b088ee4..f341596d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,7 @@ gets uploaded to AzureML, by skipping all test folders. - ([#566](https://github.com/microsoft/InnerEye-DeepLearning/pull/566)) Update `hi-ml` dependency to `hi-ml-azure`. - ([#591](https://github.com/microsoft/InnerEye-DeepLearning/pull/591)) Upgrade Pytorch Lightning to 1.5.0 - ([#572](https://github.com/microsoft/InnerEye-DeepLearning/pull/572)) Updated to new version of hi-ml package +- ([#623](https://github.com/microsoft/InnerEye-DeepLearning/pull/623)) Save checkpoints in SSLOnlineEvaluator without DDP wrapper code - ([#617](https://github.com/microsoft/InnerEye-DeepLearning/pull/617)) Provide an easier way for LightningContainers to add callbacks. - ([#596](https://github.com/microsoft/InnerEye-DeepLearning/pull/596)) Add `cudatoolkit=11.1` specification to environment.yml. - ([#615](https://github.com/microsoft/InnerEye-DeepLearning/pull/615)) Minor changes to checkpoint download from AzureML. diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 8b5079592..1cdde6beb 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -3,6 +3,8 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, Union + import pytorch_lightning as pl import torch from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator @@ -12,7 +14,6 @@ from torch.nn import SyncBatchNorm, functional as F from torch.nn.parallel import DistributedDataParallel from torchmetrics import Metric -from typing import Any, Dict, List, Optional, Set, Tuple, Union from InnerEye.ML.SSL.utils import SSLDataModuleType, add_submodules_to_same_device from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve @@ -49,30 +50,36 @@ def __init__(self, Accuracy05()] \ if self.num_classes == 2 else [Accuracy05()] self.class_weights = class_weights - self.evaluator = SSLEvaluator(n_input=self.z_dim, - n_classes=self.num_classes, - p=self.drop_p, - n_hidden=self.hidden_dim) - self.optimizer = torch.optim.Adam(self.evaluator.parameters(), - lr=self.learning_rate, - weight_decay=self.weight_decay) + self.evaluator_state: Optional[OrderedDict] = None + self.optimizer_state: Optional[OrderedDict] = None + + def _wrapped_evaluator(self) -> torch.nn.Module: + """ + Gets the evaluator model that is wrapped in DDP, or the evaluator model itself. + """ + if isinstance(self.evaluator, DistributedDataParallel): + return self.evaluator.module + else: + return self.evaluator def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]) -> Dict[str, Any]: # Each callback gets its own state dictionary, that are fed back in during load + # When saving the evaluator, use the wrapped DDP module (otherwise the resulting checkpoint will depend + # on use of DDP or not). return { self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), - self.EVALUATOR_STATE_NAME: self.evaluator.state_dict() + self.EVALUATOR_STATE_NAME: self._wrapped_evaluator().state_dict() } def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, callback_state: Dict[str, Any]) -> None: - self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME]) - self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME]) + self.optimizer_state = callback_state[self.OPTIMIZER_STATE_NAME] + self.evaluator_state = callback_state[self.EVALUATOR_STATE_NAME] def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ @@ -82,6 +89,10 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning """ for prefix, metrics in [("train", self.train_metrics), ("val", self.val_metrics)]: add_submodules_to_same_device(pl_module, metrics, prefix=prefix) + self.evaluator = SSLEvaluator(n_input=self.z_dim, + n_classes=self.num_classes, + p=self.drop_p, + n_hidden=self.hidden_dim) self.evaluator.to(pl_module.device) if hasattr(trainer, "accelerator_connector"): # This works with Lightning 1.3.8 @@ -98,6 +109,13 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning else: rank_zero_warn("This type of distributed accelerator is not supported. " "The online evaluator will not synchronize across GPUs.") + self.optimizer = torch.optim.Adam(self.evaluator.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay) + if self.evaluator_state is not None: + self._wrapped_evaluator().load_state_dict(self.evaluator_state) + if self.optimizer_state is not None: + self.optimizer.load_state_dict(self.optimizer_state) @staticmethod def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]: diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 111a84f8f..86d19aa18 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -321,9 +321,6 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No dataset="foo", drop_p=0.2, learning_rate=1e-5) - # Ensure that the parameters are really different initially - parameters2_before_training = list(callback2.evaluator.parameters()) - assert not torch.allclose(parameters2_before_training[0], parameters1[0]) # Start a second training run with recovery last_checkpoint = checkpoints.last_model_path trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir), @@ -364,19 +361,24 @@ def test_online_evaluator_not_distributed() -> None: # Test the flag that the internal logic of on_pretrain_routine_start uses assert hasattr(trainer, "_accelerator_connector") assert not trainer._accelerator_connector.is_distributed - mock_module = mock.MagicMock(device=torch.device("cpu")) - callback.on_pretrain_routine_start(trainer, mock_module) + cpu = torch.device("cpu") + callback.on_pretrain_routine_start(trainer, mock.MagicMock(device=cpu)) assert isinstance(callback.evaluator, Module) mock_ddp.assert_not_called() + # Check that the evaluator is on the GPU before making any changes + assert list(callback.evaluator.parameters())[0].device == cpu + # Check that the evaluator is really moved to the right device + gpu0 = torch.device("cuda:0") + callback.on_pretrain_routine_start(trainer, mock.MagicMock(device=gpu0)) + assert list(callback.evaluator.parameters())[0].device == gpu0 -@pytest.mark.gpu def test_online_evaluator_distributed() -> None: """ Check if the online evaluator uses the DDP flag correctly when running distributed. """ - mock_ddp_result = "mock_ddp_result" - mock_sync_result = "mock_sync_result" + mock_ddp_result = torch.nn.Linear(in_features=10, out_features=1) + mock_sync_result = torch.nn.Linear(in_features=20, out_features=2) with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SyncBatchNorm.convert_sync_batchnorm", return_value=mock_sync_result) as mock_sync: with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", @@ -389,16 +391,15 @@ def test_online_evaluator_distributed() -> None: learning_rate=1e-5) # Trainer with DDP - device = torch.device("cuda:0") + device = torch.device("cpu") mock_module = mock.MagicMock(device=device) - trainer = Trainer(accelerator="ddp", gpus=2) + trainer = Trainer(strategy="ddp", num_processes=2) # Test the two flags that the internal logic of on_pretrain_routine_start uses assert trainer._accelerator_connector.is_distributed assert trainer._accelerator_connector.use_ddp - original_evaluator = callback.evaluator callback.on_pretrain_routine_start(trainer, mock_module) # Check that SyncBatchNorm has been turned on - mock_sync.assert_called_once_with(original_evaluator) + mock_sync.assert_called_once() # Check that the evaluator has been turned into a DDP object # We still need to mock DDP here because the constructor relies on having a process group available mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device])