diff --git a/CHANGELOG.md b/CHANGELOG.md index f11a127cb..fe38529f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,7 @@ gets uploaded to AzureML, by skipping all test folders. ### Fixed - ([#606](https://github.com/microsoft/InnerEye-DeepLearning/pull/606)) Bug fix: registered models do not include the hi-ml submodule +- ([#646](https://github.com/microsoft/InnerEye-DeepLearning/pull/646)) Workaround for bug in PL: CombinedLoader cannot be used for training data when using DDP - ([#593](https://github.com/microsoft/InnerEye-DeepLearning/pull/593)) Bug fix for hi-ml 0.1.11 issue (#130): empty mount point is turned into ".", which fails the AML job - ([#587](https://github.com/microsoft/InnerEye-DeepLearning/pull/587)) Bug fix for regression in AzureML's handling of environments: upgrade to hi-ml 0.1.11 - ([#625](https://github.com/microsoft/InnerEye-DeepLearning/pull/625)) updates to PandaDeepMIL to enable the use of a SSL pre-trained checkpoint and updated commit to hi-ml diff --git a/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py b/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py index c399b1b4b..b96d03bf2 100644 --- a/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py +++ b/InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py @@ -5,14 +5,14 @@ import logging import os -from typing import Any, Callable, Optional, Sized, Union +from typing import Any, Callable, Dict, Optional, Sized, Union import numpy as np import torch from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pytorch_lightning import LightningDataModule from pytorch_lightning.trainer.supporters import CombinedLoader -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from InnerEye.ML.SSL.utils import SSLDataModuleType @@ -122,16 +122,30 @@ def __init__(self, if use_balanced_loss_linear_head: self.class_weights = self.linear_head_module.compute_class_weights() self.batch_size = self.encoder_module.batch_size + self.train_loader_cycle_mode: Optional[str] = None + self._is_prepared = False def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ Saves files to data_dir """ + if self._is_prepared: + return self.encoder_module.prepare_data() self.linear_head_module.prepare_data() - logging.info(f"Length of encoder train dataloader {len(self.encoder_module.train_dataloader())}") - logging.info(f"Length of linear head train dataloader {len(self.linear_head_module.train_dataloader())}") + len_encoder_train = len(self.encoder_module.train_dataloader()) + len_linear_head_train = len(self.linear_head_module.train_dataloader()) + logging.info(f"Length of encoder train dataloader {len_encoder_train}") + logging.info(f"Length of linear head train dataloader {len_linear_head_train}") logging.info(f"Length of total train dataloader {len(self.train_dataloader())}") + # Workaround for a bug in PL: We can't use a CombinedLoader for the training data. Instead, + # need to return a dictionary and set a cycle mode flag on the trainer. This flag can only be computed + # once the data is prepared. We read this flag out later before we construct the Trainer object. + self.train_loader_cycle_mode = self._cycle_mode(len_encoder_train, len_linear_head_train) + self._is_prepared = True + + def _cycle_mode(self, len_encoder: int, len_linear_head: int) -> str: + return "max_size_cycle" if len_encoder > len_linear_head else "min_size" def get_combined_loader(self, encoder_loader: Sized, linear_head_loader: Sized) -> CombinedLoader: """ @@ -140,19 +154,21 @@ def get_combined_loader(self, encoder_loader: Sized, linear_head_loader: Sized) :param encoder_loader: The dataloader to use for the SSL encoder. :param linear_head_loader: The dataloader to use for the linear head. """ - mode = "max_size_cycle" if len(encoder_loader) > len(linear_head_loader) else "min_size" + mode = self._cycle_mode(len(encoder_loader), len(linear_head_loader)) dataloaders = { SSLDataModuleType.ENCODER: encoder_loader, SSLDataModuleType.LINEAR_HEAD: linear_head_loader } return CombinedLoader(dataloaders, mode=mode) - def train_dataloader(self, *args: Any, **kwargs: Any) -> CombinedLoader: # type: ignore + def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]: # type: ignore """ The train dataloaders """ - return self.get_combined_loader(encoder_loader=self.encoder_module.train_dataloader(), - linear_head_loader=self.linear_head_module.train_dataloader()) + return { + SSLDataModuleType.ENCODER: self.encoder_module.train_dataloader(), + SSLDataModuleType.LINEAR_HEAD: self.linear_head_module.train_dataloader() + } def val_dataloader(self, *args: Any, **kwargs: Any) -> CombinedLoader: # type: ignore """ diff --git a/InnerEye/ML/SSL/encoders.py b/InnerEye/ML/SSL/encoders.py index ab308e7fa..9a83836de 100644 --- a/InnerEye/ML/SSL/encoders.py +++ b/InnerEye/ML/SSL/encoders.py @@ -79,9 +79,10 @@ def get_encoder_output_dim( from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import ( SSLOnlineEvaluatorInnerEye, ) - - batch = next(iter(dm.train_dataloader())) - batch = batch[SSLDataModuleType.LINEAR_HEAD] if isinstance(batch, dict) else batch # type: ignore + loaders = dm.train_dataloader() + loader = loaders[SSLDataModuleType.LINEAR_HEAD] if isinstance(loaders, dict) else loaders # type: ignore + iterator = iter(loader) + batch = next(iterator) x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device) else: x = torch.rand((1, 3, 256, 256)).to(device) diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py index 3257110db..e2b21a646 100644 --- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py +++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py @@ -83,7 +83,7 @@ class SSLContainer(LightningContainer): use_balanced_binary_loss_for_linear_head = param.Boolean(default=False, doc="Whether to use a balanced loss for the training of " "the linear head") - num_workers = param.Integer(default=6, doc="Number of workers to use for dataloader processes.") + num_workers = param.Integer(default=4, doc="Number of workers to use for dataloader processes.") is_debug_model = param.Boolean(default=False, doc="If True, the training will be restricted to 1 batch per epoch." "Used for debugging and tests.") diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index f247dcf2d..89aab6ea8 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -17,6 +17,7 @@ from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, change_working_directory from InnerEye.Common.resource_monitor import ResourceMonitor +from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule from InnerEye.ML.common import ARGS_TXT, AUTOSAVE_CHECKPOINT_FILE_NAME, ModelExecutionMode, \ VISUALIZATION_FOLDER from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning @@ -57,7 +58,8 @@ def write_args_file(config: Any, outputs_folder: Path) -> None: def create_lightning_trainer(container: LightningContainer, resume_from_checkpoint: Optional[Path] = None, - num_nodes: int = 1) -> \ + num_nodes: int = 1, + multiple_trainloader_mode: str = "max_size_cycle") -> \ Tuple[Trainer, StoringLogger]: """ Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers @@ -174,6 +176,7 @@ def create_lightning_trainer(container: LightningContainer, detect_anomaly=container.detect_anomaly, profiler=container.pl_profiler, resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None, + multiple_trainloader_mode=multiple_trainloader_mode, **additional_args) return trainer, storing_logger @@ -230,6 +233,14 @@ def model_train(checkpoint_path: Optional[Path], container.before_training_on_local_rank_zero() container.before_training_on_all_ranks() + # Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument + # because training data that uses a CombinedLoader is not split correctly in DDP + multiple_trainloader_mode = "max_size_cycle" + if isinstance(data_module, CombinedDataModule): + data_module.prepare_data() + assert data_module.train_loader_cycle_mode is not None, "This field should be computed during prepare_data" + multiple_trainloader_mode = data_module.train_loader_cycle_mode + # Create the trainer object. Backup the environment variables before doing that, in case we need to run a second # training in the unit tests.d old_environ = dict(os.environ) @@ -238,7 +249,8 @@ def model_train(checkpoint_path: Optional[Path], seed_everything(container.get_effective_random_seed()) trainer, storing_logger = create_lightning_trainer(container, checkpoint_path, - num_nodes=num_nodes) + num_nodes=num_nodes, + multiple_trainloader_mode=multiple_trainloader_mode) rank_info = ", ".join(f"{env}: {os.getenv(env)}" for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK]) logging.info(f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}") diff --git a/Tests/SSL/test_data_modules.py b/Tests/SSL/test_data_modules.py index 2fa750f93..85f165b0f 100644 --- a/Tests/SSL/test_data_modules.py +++ b/Tests/SSL/test_data_modules.py @@ -216,9 +216,6 @@ def test_combined_data_module() -> None: assert torch.isclose(combined_loader.class_weights, torch.tensor([0.21, 0.79], dtype=torch.float32), atol=1e-3).all() - train_dataloader = combined_loader.train_dataloader() - assert isinstance(train_dataloader, CombinedLoader) - indices_classifier_module_short = [] val_dataloader = combined_loader.val_dataloader() assert isinstance(val_dataloader, CombinedLoader) diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py index b158f7bed..95968be36 100644 --- a/Tests/SSL/test_ssl_containers.py +++ b/Tests/SSL/test_ssl_containers.py @@ -130,8 +130,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None: # Check the metrics that were recorded during training # Note: It is possible that after the PyTorch 1.10 upgrade, we can't get parity between local runs and runs on # the hosted build agents. If that suspicion is confirmed, we need to add branching for local and cloud results. - expected_metrics = {'simclr/val/loss': 2.8797268867492676, - 'ssl_online_evaluator/val/loss': 2.272602081298828, + expected_metrics = {'simclr/val/loss': 2.8736934661865234, + 'ssl_online_evaluator/val/loss': 2.2684895992279053, 'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.20000000298023224, 'simclr/train/loss': 3.6261773109436035, 'simclr/learning_rate': 0.0, @@ -615,8 +615,12 @@ def test_simclr_dataset_length(test_output_dirs: OutputFolderForTests, model = container.create_model() expected_num_train_iters = (num_encoder_images * 0.9) // encoder_batch_size assert model.train_iters_per_epoch == expected_num_train_iters - train_loaders = container.get_data_module().train_dataloader() - assert isinstance(train_loaders, CombinedLoader) + data_module = container.get_data_module() + data_module.prepare_data() + train_loaders_dict = data_module.train_dataloader() + assert isinstance(train_loaders_dict, dict) + assert data_module.train_loader_cycle_mode + train_loaders = CombinedLoader(train_loaders_dict, mode=data_module.train_loader_cycle_mode) assert len(train_loaders) == expected_num_train_iters expected_num_val_iters = (num_encoder_images * 0.1) // encoder_batch_size val_loaders = container.get_data_module().val_dataloader() diff --git a/azure-pipelines/build-pr.yml b/azure-pipelines/build-pr.yml index 04b85fee6..da6454f42 100644 --- a/azure-pipelines/build-pr.yml +++ b/azure-pipelines/build-pr.yml @@ -31,7 +31,7 @@ jobs: - job: Linux pool: - vmImage: 'ubuntu-18.04' + vmImage: 'ubuntu-20.04' steps: - template: build.yaml