diff --git a/CHANGELOG.md b/CHANGELOG.md index 89c9bc11f..7842737d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,7 @@ in inference-only runs when using lightning containers. - ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs. - ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided. +- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training ### Removed diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 16eb1161b..d39e1b14f 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -9,21 +9,24 @@ import torch from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.models.self_supervised.evaluator import SSLEvaluator +from pytorch_lightning.utilities import rank_zero_warn from torch import Tensor as T -from health_ml.utils import log_on_epoch -from torch.nn import functional as F +from torch.nn import SyncBatchNorm, functional as F +from torch.nn.parallel import DistributedDataParallel from torchmetrics import Metric from InnerEye.ML.SSL.utils import SSLDataModuleType from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve +from InnerEye.ML.utils.layer_util import set_model_to_eval_mode +from health_ml.utils import log_on_epoch BatchType = Union[Dict[SSLDataModuleType, Any], Any] -OPTIMIZER_STATE_NAME = "evaluator_optimizer" -EVALUATOR_STATE_NAME = "evaluator_weights" - class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator): + OPTIMIZER_STATE_NAME = "evaluator_optimizer" + EVALUATOR_STATE_NAME = "evaluator_weights" + def __init__(self, learning_rate: float, class_weights: Optional[torch.Tensor] = None, @@ -47,11 +50,11 @@ def __init__(self, Accuracy05()] \ if self.num_classes == 2 else [Accuracy05()] self.class_weights = class_weights - self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim, - n_classes=self.num_classes, - p=self.drop_p, - n_hidden=self.hidden_dim) - self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(), + self.evaluator = SSLEvaluator(n_input=self.z_dim, + n_classes=self.num_classes, + p=self.drop_p, + n_hidden=self.hidden_dim) + self.optimizer = torch.optim.Adam(self.evaluator.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) @@ -61,24 +64,34 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: # Each callback gets its own state dictionary, that are fed back in during load return { - OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), - EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict() + self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), + self.EVALUATOR_STATE_NAME: self.evaluator.state_dict() } def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, callback_state: Dict[str, Any]) -> None: - self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME]) - self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME]) + self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME]) + self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME]) def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ - Initializes modules and moves metrics and class weights to module device + Moves metrics and the online evaluator to the correct GPU. + If training happens via DDP, SyncBatchNorm is enabled for the online evaluator, and it is converted to + a DDP module. """ for metric in [*self.train_metrics, *self.val_metrics]: metric.to(device=pl_module.device) # type: ignore - self.non_linear_evaluator.to(pl_module.device) + self.evaluator.to(pl_module.device) + accelerator = trainer.accelerator_connector + if accelerator.is_distributed: + if accelerator.use_ddp: + self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator) + self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore + else: + rank_zero_warn("This type of distributed accelerator is not supported. " + "The online evaluator will not synchronize across GPUs.") @staticmethod def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]: @@ -108,7 +121,7 @@ def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_traini representations = representations.detach() # Run the linear-head with SSL embeddings. - mlp_preds = self.non_linear_evaluator(representations) + mlp_preds = self.evaluator(representations) weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device) mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights) @@ -133,15 +146,11 @@ def on_validation_batch_end(self, trainer: pl.Trainer, ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist()) if ids_linear_head not in self.visited_ids: self.visited_ids.add(ids_linear_head) - # Put the online evaluator into "eval" mode - old_mode = self.non_linear_evaluator.training - self.non_linear_evaluator.eval() - loss = self.shared_step(batch, pl_module, is_training=False) - log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss) - for metric in self.val_metrics: - log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric) - # Put the online evaluator back into the state (eval or train) that it was before calling this method - self.non_linear_evaluator.train(old_mode) + with set_model_to_eval_mode(self.evaluator): + loss = self.shared_step(batch, pl_module, is_training=False) + log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss) + for metric in self.val_metrics: + log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore """ diff --git a/InnerEye/ML/utils/layer_util.py b/InnerEye/ML/utils/layer_util.py index f8463019b..e300e838d 100644 --- a/InnerEye/ML/utils/layer_util.py +++ b/InnerEye/ML/utils/layer_util.py @@ -2,7 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Iterable, Sized, Tuple, Union +from contextlib import contextmanager +from typing import Generator, Iterable, Sized, Tuple, Union import torch from torch.nn import init @@ -90,3 +91,16 @@ def upsample_size(down: int) -> int: upsample_size(downsampling_factor[1]), # type: ignore upsample_size(downsampling_factor[2])) # type: ignore return upsampling_kernel_size + + +@contextmanager +def set_model_to_eval_mode(model: torch.nn.Module) -> Generator: + """ + Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to + what is was before the call. + :param model: The model to modify. + """ + old_mode = model.training + model.eval() + yield + model.train(old_mode) diff --git a/InnerEye/ML/visualizers/model_summary.py b/InnerEye/ML/visualizers/model_summary.py index 2d3e87e4b..f0b4009ea 100644 --- a/InnerEye/ML/visualizers/model_summary.py +++ b/InnerEye/ML/visualizers/model_summary.py @@ -17,6 +17,7 @@ from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH from InnerEye.ML.utils.device_aware_module import DeviceAwareModule from InnerEye.ML.utils.ml_util import RandomStateSnapshot +from InnerEye.ML.utils.layer_util import set_model_to_eval_mode @dataclass @@ -217,15 +218,11 @@ def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor] inputs = [input_tensor.cuda() for input_tensor in inputs] # collect the current state of the model - is_train = module.training module_state = RandomStateSnapshot.snapshot_random_state() - # set the model in evaluation mode and perform a forward pass - module.eval() - with torch.no_grad(): - output = module.forward(*inputs) - if is_train: - module.train() + with set_model_to_eval_mode(module): + with torch.no_grad(): + output = module.forward(*inputs) # restore the seed for torch and numpy module_state.restore_random_state() diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 0481df212..fb1ff4935 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -12,22 +12,26 @@ import pytest import torch from pl_bolts.models.self_supervised.resnets import ResNet +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.nn import Module from torch.optim.lr_scheduler import _LRScheduler from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import is_windows from InnerEye.Common.fixed_paths import repository_root_directory from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path +from InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier -from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \ - SSLOnlineEvaluatorInnerEye +from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier from InnerEye.ML.runner import Runner +from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel from Tests.ML.utils.test_io_util import write_test_dicom path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset") @@ -133,8 +137,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert "callbacks" in checkpoint assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"] callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye] - assert OPTIMIZER_STATE_NAME in callback_state - assert EVALUATOR_STATE_NAME in callback_state + assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state + assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state # Now run the actual SSL classifier off the stored checkpoint args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"] @@ -268,3 +272,130 @@ def test_simclr_lr_scheduler() -> None: assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}" for i in range(highest_lr, len(lr) - 1): assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}" + + +def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None: + """ + Test checkpoint recovery for the online evaluator in an end-to-end training run. + """ + container = DummyContainerWithModel() + model = container.create_model() + data = container.get_data_module() + checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints") + checkpoint_folder.mkdir(exist_ok=True) + checkpoints = ModelCheckpoint(dirpath=checkpoint_folder, + every_n_val_epochs=1, + save_last=True) + # Create a first callback, that will be used in training. + callback1 = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + # To simplify the test setup, do not run any actual training (this would require complicated dataset with a + # combined loader) + with mock.patch( + "InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye.on_train_batch_end", + return_value=None) as mock_train: + with mock.patch( + "InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye" + ".on_validation_batch_end", + return_value=None): + trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir), + callbacks=[checkpoints, callback1], + max_epochs=10) + trainer.fit(model, datamodule=data) + # Check that the callback was actually used + mock_train.assert_called() + # Now read out the parameters of the callback. + # We will then run a second training job, with a new callback object, that will be initialized randomly, + # and should have different parameters initially. After checkpoint recovery, it should have exactly the + # same parameters as the first callback. + parameters1 = list(callback1.evaluator.parameters()) + callback2 = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + # Ensure that the parameters are really different initially + parameters2_before_training = list(callback2.evaluator.parameters()) + assert not torch.allclose(parameters2_before_training[0], parameters1[0]) + # Start a second training run with recovery + last_checkpoint = checkpoints.last_model_path + trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir), + callbacks=[callback2], + max_epochs=20, + resume_from_checkpoint=last_checkpoint) + trainer2.fit(model, datamodule=data) + # Read the parameters and check if they are the same as what was stored in the first callback. + parameters2_after_training = list(callback2.evaluator.parameters()) + assert torch.allclose(parameters2_after_training[0], parameters1[0]) + + # It's somewhat obsolete, but we can now check that the checkpoint file really contained the optimizer and weights + checkpoint = torch.load(last_checkpoint) + assert "callbacks" in checkpoint + assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"] + callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye] + assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state + assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state + + +@pytest.mark.gpu +def test_online_evaluator_not_distributed() -> None: + """ + Check if the online evaluator uses the DDP flag correctly when running not distributed + """ + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp: + callback = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + mock_ddp.assert_not_called() + + # Standard trainer without DDP + trainer = Trainer() + # Test the flag that the internal logic of on_pretrain_routine_start uses + assert not trainer.accelerator_connector.is_distributed + mock_module = mock.MagicMock(device=torch.device("cpu")) + callback.on_pretrain_routine_start(trainer, mock_module) + assert isinstance(callback.evaluator, Module) + mock_ddp.assert_not_called() + + +@pytest.mark.gpu +def test_online_evaluator_distributed() -> None: + """ + Check if the online evaluator uses the DDP flag correctly when running distributed. + """ + mock_ddp_result = "mock_ddp_result" + mock_sync_result = "mock_sync_result" + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SyncBatchNorm.convert_sync_batchnorm", + return_value=mock_sync_result) as mock_sync: + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", + return_value=mock_ddp_result) as mock_ddp: + callback = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + + # Trainer with DDP + device = torch.device("cuda:0") + mock_module = mock.MagicMock(device=device) + trainer = Trainer(accelerator="ddp", gpus=2) + # Test the two flags that the internal logic of on_pretrain_routine_start uses + assert trainer.accelerator_connector.is_distributed + assert trainer.accelerator_connector.use_ddp + original_evaluator = callback.evaluator + callback.on_pretrain_routine_start(trainer, mock_module) + # Check that SyncBatchNorm has been turned on + mock_sync.assert_called_once_with(original_evaluator) + # Check that the evaluator has been turned into a DDP object + # We still need to mock DDP here because the constructor relies on having a process group available + mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device]) + assert callback.evaluator == mock_ddp_result