From a2fa635fb1bbf036c149ff823b85a7df82d7436c Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 21:54:09 +0000 Subject: [PATCH 01/14] hi-ml main --- hi-ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hi-ml b/hi-ml index 334186321..38fd68557 160000 --- a/hi-ml +++ b/hi-ml @@ -1 +1 @@ -Subproject commit 334186321f6989033f5609880781ba4c299f6f67 +Subproject commit 38fd685579749e6c5e5f8c199c76c5854394421c From 35382f39cf5098d57132baf3e6b94d7e67c1799f Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 22:13:46 +0000 Subject: [PATCH 02/14] context manager --- InnerEye/ML/utils/model_util.py | 16 +++++++++++++++- InnerEye/ML/visualizers/model_summary.py | 11 ++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/InnerEye/ML/utils/model_util.py b/InnerEye/ML/utils/model_util.py index d182f3432..10baaf446 100644 --- a/InnerEye/ML/utils/model_util.py +++ b/InnerEye/ML/utils/model_util.py @@ -3,8 +3,9 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ import logging +from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union +from typing import Any, Dict, Generator, Generic, Iterator, List, Optional, TypeVar, Union import torch from torch.nn import MSELoss @@ -321,3 +322,16 @@ def get_scalar_model_inputs_and_labels(model: torch.nn.Module, subject_ids=subject_ids, data_item=scalar_item ) + + +@contextmanager +def set_model_to_eval_mode(model: torch.nn.Module) -> Generator: + """ + Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to + what is was before the call. + :param model: The model to modify. + """ + old_mode = model.training + model.eval() + yield + model.train(old_mode) diff --git a/InnerEye/ML/visualizers/model_summary.py b/InnerEye/ML/visualizers/model_summary.py index 2d3e87e4b..12acfc1ca 100644 --- a/InnerEye/ML/visualizers/model_summary.py +++ b/InnerEye/ML/visualizers/model_summary.py @@ -17,6 +17,7 @@ from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH from InnerEye.ML.utils.device_aware_module import DeviceAwareModule from InnerEye.ML.utils.ml_util import RandomStateSnapshot +from InnerEye.ML.utils.model_util import set_model_to_eval_mode @dataclass @@ -217,15 +218,11 @@ def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor] inputs = [input_tensor.cuda() for input_tensor in inputs] # collect the current state of the model - is_train = module.training module_state = RandomStateSnapshot.snapshot_random_state() - # set the model in evaluation mode and perform a forward pass - module.eval() - with torch.no_grad(): - output = module.forward(*inputs) - if is_train: - module.train() + with set_model_to_eval_mode(module): + with torch.no_grad(): + output = module.forward(*inputs) # restore the seed for torch and numpy module_state.restore_random_state() From d95b45da91b4604b3b6ba95b48ed3c37e3cf1a7e Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 22:14:51 +0000 Subject: [PATCH 03/14] test for checkpoints --- Tests/SSL/test_ssl_containers.py | 81 ++++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 0481df212..955991323 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -12,22 +12,25 @@ import pytest import torch from pl_bolts.models.self_supervised.resnets import ResNet +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from torch.optim.lr_scheduler import _LRScheduler from InnerEye.Common import fixed_paths from InnerEye.Common.common_util import is_windows from InnerEye.Common.fixed_paths import repository_root_directory from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path +from InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier -from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \ - SSLOnlineEvaluatorInnerEye +from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier from InnerEye.ML.runner import Runner +from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel from Tests.ML.utils.test_io_util import write_test_dicom path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset") @@ -133,8 +136,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: assert "callbacks" in checkpoint assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"] callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye] - assert OPTIMIZER_STATE_NAME in callback_state - assert EVALUATOR_STATE_NAME in callback_state + assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state + assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state # Now run the actual SSL classifier off the stored checkpoint args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"] @@ -268,3 +271,73 @@ def test_simclr_lr_scheduler() -> None: assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}" for i in range(highest_lr, len(lr) - 1): assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}" + + +def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None: + """ + Test checkpoint recovery for the online evaluator in an end-to-end training run. + """ + container = DummyContainerWithModel() + model = container.create_model() + data = container.get_data_module() + checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints") + checkpoint_folder.mkdir(exist_ok=True) + checkpoints = ModelCheckpoint(dirpath=checkpoint_folder, + every_n_val_epochs=1, + save_last=True) + # Create a first callback, that will be used in training. + callback1 = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + # To simplify the test setup, do not run any actual training (this would require complicated dataset with a + # combined loader) + with mock.patch( + "InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye.on_train_batch_end", + return_value=None) as mock_train: + with mock.patch( + "InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye" + ".on_validation_batch_end", + return_value=None): + trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir), + callbacks=[checkpoints, callback1], + max_epochs=10) + trainer.fit(model, datamodule=data) + # Check that the callback was actually used + mock_train.assert_called() + # Now read out the parameters of the callback. + # We will then run a second training job, with a new callback object, that will be initialized randomly, + # and should have different parameters initially. After checkpoint recovery, it should have exactly the + # same parameters as the first callback. + parameters1 = list(callback1.evaluator.parameters()) + callback2 = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + # Ensure that the parameters are really different initially + parameters2_before_training = list(callback2.evaluator.parameters()) + assert not torch.allclose(parameters2_before_training[0], parameters1[0]) + # Start a second training run with recovery + last_checkpoint = checkpoints.last_model_path + trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir), + callbacks=[callback2], + max_epochs=20, + resume_from_checkpoint=last_checkpoint) + trainer2.fit(model, datamodule=data) + # Read the parameters and check if they are the same as what was stored in the first callback. + parameters2_after_training = list(callback2.evaluator.parameters()) + assert torch.allclose(parameters2_after_training[0], parameters1[0]) + + # It's somewhat obsolete, but we can now check that the checkpoint file really contained the optimizer and weights + checkpoint = torch.load(last_checkpoint) + assert "callbacks" in checkpoint + assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"] + callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye] + assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state + assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state + + From fb62b659c2a5d14505c2182d3b3c5ba6273363bf Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 22:14:58 +0000 Subject: [PATCH 04/14] DDP --- .../lightning_modules/ssl_online_evaluator.py | 58 +++++++++++-------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 16eb1161b..292e0cce8 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -9,21 +9,24 @@ import torch from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator from pl_bolts.models.self_supervised.evaluator import SSLEvaluator +from pytorch_lightning.utilities import rank_zero_warn from torch import Tensor as T -from health_ml.utils import log_on_epoch -from torch.nn import functional as F +from torch.nn import DataParallel, functional as F +from torch.nn.parallel import DistributedDataParallel from torchmetrics import Metric from InnerEye.ML.SSL.utils import SSLDataModuleType from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve +from InnerEye.ML.utils.model_util import set_model_to_eval_mode +from health_ml.utils import log_on_epoch BatchType = Union[Dict[SSLDataModuleType, Any], Any] -OPTIMIZER_STATE_NAME = "evaluator_optimizer" -EVALUATOR_STATE_NAME = "evaluator_weights" - class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator): + OPTIMIZER_STATE_NAME = "evaluator_optimizer" + EVALUATOR_STATE_NAME = "evaluator_weights" + def __init__(self, learning_rate: float, class_weights: Optional[torch.Tensor] = None, @@ -47,11 +50,11 @@ def __init__(self, Accuracy05()] \ if self.num_classes == 2 else [Accuracy05()] self.class_weights = class_weights - self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim, - n_classes=self.num_classes, - p=self.drop_p, - n_hidden=self.hidden_dim) - self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(), + self.evaluator = SSLEvaluator(n_input=self.z_dim, + n_classes=self.num_classes, + p=self.drop_p, + n_hidden=self.hidden_dim) + self.optimizer = torch.optim.Adam(self.evaluator.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) @@ -61,16 +64,16 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: # Each callback gets its own state dictionary, that are fed back in during load return { - OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), - EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict() + self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(), + self.EVALUATOR_STATE_NAME: self.evaluator.state_dict() } def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, callback_state: Dict[str, Any]) -> None: - self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME]) - self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME]) + self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME]) + self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME]) def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ @@ -78,7 +81,16 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning """ for metric in [*self.train_metrics, *self.val_metrics]: metric.to(device=pl_module.device) # type: ignore - self.non_linear_evaluator.to(pl_module.device) + self.evaluator.to(pl_module.device) + accelerator = trainer.accelerator_connector + if accelerator.is_distributed: + if accelerator.use_ddp: + self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) + elif accelerator.use_dp: + self.evaluator = DataParallel(self.evaluator, device_ids=[pl_module.device]) + else: + rank_zero_warn("This type of distributed accelerator is not supported. " + "The online evaluator will not synchronize across GPUs.") @staticmethod def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]: @@ -108,7 +120,7 @@ def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_traini representations = representations.detach() # Run the linear-head with SSL embeddings. - mlp_preds = self.non_linear_evaluator(representations) + mlp_preds = self.evaluator(representations) weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device) mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights) @@ -133,15 +145,11 @@ def on_validation_batch_end(self, trainer: pl.Trainer, ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist()) if ids_linear_head not in self.visited_ids: self.visited_ids.add(ids_linear_head) - # Put the online evaluator into "eval" mode - old_mode = self.non_linear_evaluator.training - self.non_linear_evaluator.eval() - loss = self.shared_step(batch, pl_module, is_training=False) - log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss) - for metric in self.val_metrics: - log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric) - # Put the online evaluator back into the state (eval or train) that it was before calling this method - self.non_linear_evaluator.train(old_mode) + with set_model_to_eval_mode(self.evaluator): + loss = self.shared_step(batch, pl_module, is_training=False) + log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss) + for metric in self.val_metrics: + log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore """ From 58d893ac48ee76551907398332207ff47431cb70 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 22:39:05 +0000 Subject: [PATCH 05/14] tests moving context manager to avoid circular imports --- .../lightning_modules/ssl_online_evaluator.py | 2 +- InnerEye/ML/utils/layer_util.py | 16 ++++++++++++- InnerEye/ML/utils/model_util.py | 16 +------------ InnerEye/ML/visualizers/model_summary.py | 2 +- Tests/SSL/test_ssl_containers.py | 24 +++++++++++++++++++ 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 292e0cce8..72c6f9f6f 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -17,7 +17,7 @@ from InnerEye.ML.SSL.utils import SSLDataModuleType from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve -from InnerEye.ML.utils.model_util import set_model_to_eval_mode +from InnerEye.ML.utils.layer_util import set_model_to_eval_mode from health_ml.utils import log_on_epoch BatchType = Union[Dict[SSLDataModuleType, Any], Any] diff --git a/InnerEye/ML/utils/layer_util.py b/InnerEye/ML/utils/layer_util.py index f8463019b..e300e838d 100644 --- a/InnerEye/ML/utils/layer_util.py +++ b/InnerEye/ML/utils/layer_util.py @@ -2,7 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from typing import Iterable, Sized, Tuple, Union +from contextlib import contextmanager +from typing import Generator, Iterable, Sized, Tuple, Union import torch from torch.nn import init @@ -90,3 +91,16 @@ def upsample_size(down: int) -> int: upsample_size(downsampling_factor[1]), # type: ignore upsample_size(downsampling_factor[2])) # type: ignore return upsampling_kernel_size + + +@contextmanager +def set_model_to_eval_mode(model: torch.nn.Module) -> Generator: + """ + Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to + what is was before the call. + :param model: The model to modify. + """ + old_mode = model.training + model.eval() + yield + model.train(old_mode) diff --git a/InnerEye/ML/utils/model_util.py b/InnerEye/ML/utils/model_util.py index 10baaf446..d182f3432 100644 --- a/InnerEye/ML/utils/model_util.py +++ b/InnerEye/ML/utils/model_util.py @@ -3,9 +3,8 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ import logging -from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Generator, Generic, Iterator, List, Optional, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union import torch from torch.nn import MSELoss @@ -322,16 +321,3 @@ def get_scalar_model_inputs_and_labels(model: torch.nn.Module, subject_ids=subject_ids, data_item=scalar_item ) - - -@contextmanager -def set_model_to_eval_mode(model: torch.nn.Module) -> Generator: - """ - Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to - what is was before the call. - :param model: The model to modify. - """ - old_mode = model.training - model.eval() - yield - model.train(old_mode) diff --git a/InnerEye/ML/visualizers/model_summary.py b/InnerEye/ML/visualizers/model_summary.py index 12acfc1ca..f0b4009ea 100644 --- a/InnerEye/ML/visualizers/model_summary.py +++ b/InnerEye/ML/visualizers/model_summary.py @@ -17,7 +17,7 @@ from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH from InnerEye.ML.utils.device_aware_module import DeviceAwareModule from InnerEye.ML.utils.ml_util import RandomStateSnapshot -from InnerEye.ML.utils.model_util import set_model_to_eval_mode +from InnerEye.ML.utils.layer_util import set_model_to_eval_mode @dataclass diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 955991323..5c2c148a6 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -14,6 +14,8 @@ from pl_bolts.models.self_supervised.resnets import ResNet from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import _LRScheduler from InnerEye.Common import fixed_paths @@ -341,3 +343,25 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state +@pytest.mark.gpu +def test_online_evaluator_distributed() -> None: + """ + A very primitive type of test to check if the online evaluator uses the DDP flag correctly. + """ + callback = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + assert isinstance(callback.evaluator, Module) + assert not isinstance(callback.evaluator, DistributedDataParallel) + trainer = Trainer() + mock_module = mock.MagicMock(device=torch.device("cpu")) + callback.on_pretrain_routine_start(trainer, mock_module) + assert isinstance(callback.evaluator, Module) + assert not isinstance(callback.evaluator, DistributedDataParallel) + mock_module = mock.MagicMock(device=torch.device("cuda:0")) + trainer = Trainer(accelerator="ddp", gpus=2) + callback.on_pretrain_routine_start(trainer, mock_module) + assert isinstance(callback.evaluator, DistributedDataParallel) From d1f38a5fc7079e780dd83676362157a052ada79f Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 22:41:44 +0000 Subject: [PATCH 06/14] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74c6d9784..ed031b261 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,7 @@ in inference-only runs when using lightning containers. - ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs. - ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided. +- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training ### Removed From 81fafcba1308e71dafd715541ebff4a0e4b70865 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Wed, 8 Dec 2021 22:53:04 +0000 Subject: [PATCH 07/14] mypy --- InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index 72c6f9f6f..cd77312fc 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -85,9 +85,9 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning accelerator = trainer.accelerator_connector if accelerator.is_distributed: if accelerator.use_ddp: - self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) + self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore elif accelerator.use_dp: - self.evaluator = DataParallel(self.evaluator, device_ids=[pl_module.device]) + self.evaluator = DataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore else: rank_zero_warn("This type of distributed accelerator is not supported. " "The online evaluator will not synchronize across GPUs.") From dc4bd3e54c2524bec26732b61597e1c376e48ead Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 00:26:52 +0000 Subject: [PATCH 08/14] test fix --- Tests/SSL/test_ssl_containers.py | 40 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 5c2c148a6..d92309d21 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -348,20 +348,26 @@ def test_online_evaluator_distributed() -> None: """ A very primitive type of test to check if the online evaluator uses the DDP flag correctly. """ - callback = SSLOnlineEvaluatorInnerEye(class_weights=None, - z_dim=1, - num_classes=2, - dataset="foo", - drop_p=0.2, - learning_rate=1e-5) - assert isinstance(callback.evaluator, Module) - assert not isinstance(callback.evaluator, DistributedDataParallel) - trainer = Trainer() - mock_module = mock.MagicMock(device=torch.device("cpu")) - callback.on_pretrain_routine_start(trainer, mock_module) - assert isinstance(callback.evaluator, Module) - assert not isinstance(callback.evaluator, DistributedDataParallel) - mock_module = mock.MagicMock(device=torch.device("cuda:0")) - trainer = Trainer(accelerator="ddp", gpus=2) - callback.on_pretrain_routine_start(trainer, mock_module) - assert isinstance(callback.evaluator, DistributedDataParallel) + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp: + callback = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + mock_ddp.assert_not_called() + + # Standard trainer without DDP + trainer = Trainer() + mock_module = mock.MagicMock(device=torch.device("cpu")) + callback.on_pretrain_routine_start(trainer, mock_module) + assert isinstance(callback.evaluator, Module) + mock_ddp.assert_not_called() + + # Trainer with DDP + mock_device = "fake_device" + mock_module = mock.MagicMock(device=mock_device) + trainer = Trainer(accelerator="ddp", gpus=2) + callback.on_pretrain_routine_start(trainer, mock_module) + # We still need to make DDP here because the constructor relies on having a process group available + mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[mock_device]) From 5890ddde2106203e3df81b0106646b7c238c8069 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 00:27:08 +0000 Subject: [PATCH 09/14] typo --- Tests/SSL/test_ssl_containers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index d92309d21..e1ec76b22 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -369,5 +369,5 @@ def test_online_evaluator_distributed() -> None: mock_module = mock.MagicMock(device=mock_device) trainer = Trainer(accelerator="ddp", gpus=2) callback.on_pretrain_routine_start(trainer, mock_module) - # We still need to make DDP here because the constructor relies on having a process group available + # We still need to mock DDP here because the constructor relies on having a process group available mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[mock_device]) From c28bd0d7a97e0ee43a826f58b95decf6ed1e2d6e Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 00:27:33 +0000 Subject: [PATCH 10/14] typo --- Tests/SSL/test_ssl_containers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index e1ec76b22..2b8c23479 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -346,7 +346,7 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No @pytest.mark.gpu def test_online_evaluator_distributed() -> None: """ - A very primitive type of test to check if the online evaluator uses the DDP flag correctly. + A very basic test to check if the online evaluator uses the DDP flag correctly. """ with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp: callback = SSLOnlineEvaluatorInnerEye(class_weights=None, From 0c16162522e5073f4288a3f7e24782cc3dd9beba Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 00:30:56 +0000 Subject: [PATCH 11/14] test fix --- Tests/SSL/test_ssl_containers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 2b8c23479..2a372928f 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -348,7 +348,9 @@ def test_online_evaluator_distributed() -> None: """ A very basic test to check if the online evaluator uses the DDP flag correctly. """ - with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp: + mock_ddp_result = "mock_ddp_result" + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", + return_value=mock_ddp_result) as mock_ddp: callback = SSLOnlineEvaluatorInnerEye(class_weights=None, z_dim=1, num_classes=2, @@ -368,6 +370,11 @@ def test_online_evaluator_distributed() -> None: mock_device = "fake_device" mock_module = mock.MagicMock(device=mock_device) trainer = Trainer(accelerator="ddp", gpus=2) + # Test the two flags that the internal logic of on_pretrain_routine_start uses + assert trainer.accelerator_connector.is_distributed + assert trainer.accelerator_connector.use_ddp callback.on_pretrain_routine_start(trainer, mock_module) + # Check that the evaluator has been turned into a DDP object # We still need to mock DDP here because the constructor relies on having a process group available mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[mock_device]) + assert callback.evaluator == mock_ddp_result From d958171db14005e11c3a271ed1edbed71a119930 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 11:47:48 +0000 Subject: [PATCH 12/14] test and flake --- .../lightning_modules/ssl_online_evaluator.py | 2 -- Tests/SSL/test_ssl_containers.py | 34 ++++++++++++++----- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index cd77312fc..a9fc1a6c2 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -86,8 +86,6 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning if accelerator.is_distributed: if accelerator.use_ddp: self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore - elif accelerator.use_dp: - self.evaluator = DataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore else: rank_zero_warn("This type of distributed accelerator is not supported. " "The online evaluator will not synchronize across GPUs.") diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 2a372928f..0a42f9b14 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -12,7 +12,7 @@ import pytest import torch from pl_bolts.models.self_supervised.resnets import ResNet -from pytorch_lightning import Trainer +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from torch.nn import Module from torch.nn.parallel import DistributedDataParallel @@ -344,13 +344,11 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No @pytest.mark.gpu -def test_online_evaluator_distributed() -> None: +def test_online_evaluator_not_distributed() -> None: """ - A very basic test to check if the online evaluator uses the DDP flag correctly. + Check if the online evaluator uses the DDP flag correctly when running not distributed """ - mock_ddp_result = "mock_ddp_result" - with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", - return_value=mock_ddp_result) as mock_ddp: + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp: callback = SSLOnlineEvaluatorInnerEye(class_weights=None, z_dim=1, num_classes=2, @@ -361,14 +359,32 @@ def test_online_evaluator_distributed() -> None: # Standard trainer without DDP trainer = Trainer() + # Test the flag that the internal logic of on_pretrain_routine_start uses + assert not trainer.accelerator_connector.is_distributed mock_module = mock.MagicMock(device=torch.device("cpu")) callback.on_pretrain_routine_start(trainer, mock_module) assert isinstance(callback.evaluator, Module) mock_ddp.assert_not_called() + +@pytest.mark.gpu +def test_online_evaluator_distributed() -> None: + """ + Check if the online evaluator uses the DDP flag correctly when running distributed. + """ + mock_ddp_result = "mock_ddp_result" + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", + return_value=mock_ddp_result) as mock_ddp: + callback = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + # Trainer with DDP - mock_device = "fake_device" - mock_module = mock.MagicMock(device=mock_device) + device = torch.device("cuda:0") + mock_module = mock.MagicMock(device=device) trainer = Trainer(accelerator="ddp", gpus=2) # Test the two flags that the internal logic of on_pretrain_routine_start uses assert trainer.accelerator_connector.is_distributed @@ -376,5 +392,5 @@ def test_online_evaluator_distributed() -> None: callback.on_pretrain_routine_start(trainer, mock_module) # Check that the evaluator has been turned into a DDP object # We still need to mock DDP here because the constructor relies on having a process group available - mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[mock_device]) + mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[device]) assert callback.evaluator == mock_ddp_result From 6832b8d7eddba68214b58a28599d9464c211dd38 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 12:01:15 +0000 Subject: [PATCH 13/14] test and flake --- InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py | 2 +- Tests/SSL/test_ssl_containers.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index a9fc1a6c2..ad84919af 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -11,7 +11,7 @@ from pl_bolts.models.self_supervised.evaluator import SSLEvaluator from pytorch_lightning.utilities import rank_zero_warn from torch import Tensor as T -from torch.nn import DataParallel, functional as F +from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel from torchmetrics import Metric diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 0a42f9b14..598e90639 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -12,10 +12,9 @@ import pytest import torch from pl_bolts.models.self_supervised.resnets import ResNet -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from torch.nn import Module -from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import _LRScheduler from InnerEye.Common import fixed_paths From 06e7611ecea1288fdb6c8bd3def11f2aed94d5e1 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 9 Dec 2021 17:50:57 +0000 Subject: [PATCH 14/14] batchnorm, test fix --- .../lightning_modules/ssl_online_evaluator.py | 7 ++- Tests/SSL/test_ssl_containers.py | 48 +++++++++++-------- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py index ad84919af..d39e1b14f 100644 --- a/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py +++ b/InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py @@ -11,7 +11,7 @@ from pl_bolts.models.self_supervised.evaluator import SSLEvaluator from pytorch_lightning.utilities import rank_zero_warn from torch import Tensor as T -from torch.nn import functional as F +from torch.nn import SyncBatchNorm, functional as F from torch.nn.parallel import DistributedDataParallel from torchmetrics import Metric @@ -77,7 +77,9 @@ def on_load_checkpoint(self, def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ - Initializes modules and moves metrics and class weights to module device + Moves metrics and the online evaluator to the correct GPU. + If training happens via DDP, SyncBatchNorm is enabled for the online evaluator, and it is converted to + a DDP module. """ for metric in [*self.train_metrics, *self.val_metrics]: metric.to(device=pl_module.device) # type: ignore @@ -85,6 +87,7 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning accelerator = trainer.accelerator_connector if accelerator.is_distributed: if accelerator.use_ddp: + self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator) self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore else: rank_zero_warn("This type of distributed accelerator is not supported. " diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index 598e90639..fb1ff4935 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -372,24 +372,30 @@ def test_online_evaluator_distributed() -> None: Check if the online evaluator uses the DDP flag correctly when running distributed. """ mock_ddp_result = "mock_ddp_result" - with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", - return_value=mock_ddp_result) as mock_ddp: - callback = SSLOnlineEvaluatorInnerEye(class_weights=None, - z_dim=1, - num_classes=2, - dataset="foo", - drop_p=0.2, - learning_rate=1e-5) - - # Trainer with DDP - device = torch.device("cuda:0") - mock_module = mock.MagicMock(device=device) - trainer = Trainer(accelerator="ddp", gpus=2) - # Test the two flags that the internal logic of on_pretrain_routine_start uses - assert trainer.accelerator_connector.is_distributed - assert trainer.accelerator_connector.use_ddp - callback.on_pretrain_routine_start(trainer, mock_module) - # Check that the evaluator has been turned into a DDP object - # We still need to mock DDP here because the constructor relies on having a process group available - mock_ddp.assert_called_once_with(callback.evaluator, device_ids=[device]) - assert callback.evaluator == mock_ddp_result + mock_sync_result = "mock_sync_result" + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SyncBatchNorm.convert_sync_batchnorm", + return_value=mock_sync_result) as mock_sync: + with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel", + return_value=mock_ddp_result) as mock_ddp: + callback = SSLOnlineEvaluatorInnerEye(class_weights=None, + z_dim=1, + num_classes=2, + dataset="foo", + drop_p=0.2, + learning_rate=1e-5) + + # Trainer with DDP + device = torch.device("cuda:0") + mock_module = mock.MagicMock(device=device) + trainer = Trainer(accelerator="ddp", gpus=2) + # Test the two flags that the internal logic of on_pretrain_routine_start uses + assert trainer.accelerator_connector.is_distributed + assert trainer.accelerator_connector.use_ddp + original_evaluator = callback.evaluator + callback.on_pretrain_routine_start(trainer, mock_module) + # Check that SyncBatchNorm has been turned on + mock_sync.assert_called_once_with(original_evaluator) + # Check that the evaluator has been turned into a DDP object + # We still need to mock DDP here because the constructor relies on having a process group available + mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device]) + assert callback.evaluator == mock_ddp_result