Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ gets uploaded to AzureML, by skipping all test folders.

### Fixed
- ([#606](https://github.com/microsoft/InnerEye-DeepLearning/pull/606)) Bug fix: registered models do not include the hi-ml submodule
- ([#646](https://github.com/microsoft/InnerEye-DeepLearning/pull/646)) Workaround for bug in PL: CombinedLoader cannot be used for training data when using DDP
- ([#593](https://github.com/microsoft/InnerEye-DeepLearning/pull/593)) Bug fix for hi-ml 0.1.11 issue (#130): empty mount point is turned into ".", which fails the AML job
- ([#587](https://github.com/microsoft/InnerEye-DeepLearning/pull/587)) Bug fix for regression in AzureML's handling of environments: upgrade to hi-ml 0.1.11
- ([#625](https://github.com/microsoft/InnerEye-DeepLearning/pull/625)) updates to PandaDeepMIL to enable the use of a SSL pre-trained checkpoint and updated commit to hi-ml
Expand Down
32 changes: 24 additions & 8 deletions InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

import logging
import os
from typing import Any, Callable, Optional, Sized, Union
from typing import Any, Callable, Dict, Optional, Sized, Union

import numpy as np
import torch
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pytorch_lightning import LightningDataModule
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader

from InnerEye.ML.SSL.utils import SSLDataModuleType

Expand Down Expand Up @@ -122,16 +122,30 @@ def __init__(self,
if use_balanced_loss_linear_head:
self.class_weights = self.linear_head_module.compute_class_weights()
self.batch_size = self.encoder_module.batch_size
self.train_loader_cycle_mode: Optional[str] = None
self._is_prepared = False

def prepare_data(self, *args: Any, **kwargs: Any) -> None:
"""
Saves files to data_dir
"""
if self._is_prepared:
return
self.encoder_module.prepare_data()
self.linear_head_module.prepare_data()
logging.info(f"Length of encoder train dataloader {len(self.encoder_module.train_dataloader())}")
logging.info(f"Length of linear head train dataloader {len(self.linear_head_module.train_dataloader())}")
len_encoder_train = len(self.encoder_module.train_dataloader())
len_linear_head_train = len(self.linear_head_module.train_dataloader())
logging.info(f"Length of encoder train dataloader {len_encoder_train}")
logging.info(f"Length of linear head train dataloader {len_linear_head_train}")
logging.info(f"Length of total train dataloader {len(self.train_dataloader())}")
# Workaround for a bug in PL: We can't use a CombinedLoader for the training data. Instead,
# need to return a dictionary and set a cycle mode flag on the trainer. This flag can only be computed
# once the data is prepared. We read this flag out later before we construct the Trainer object.
self.train_loader_cycle_mode = self._cycle_mode(len_encoder_train, len_linear_head_train)
self._is_prepared = True

def _cycle_mode(self, len_encoder: int, len_linear_head: int) -> str:
return "max_size_cycle" if len_encoder > len_linear_head else "min_size"

def get_combined_loader(self, encoder_loader: Sized, linear_head_loader: Sized) -> CombinedLoader:
"""
Expand All @@ -140,19 +154,21 @@ def get_combined_loader(self, encoder_loader: Sized, linear_head_loader: Sized)
:param encoder_loader: The dataloader to use for the SSL encoder.
:param linear_head_loader: The dataloader to use for the linear head.
"""
mode = "max_size_cycle" if len(encoder_loader) > len(linear_head_loader) else "min_size"
mode = self._cycle_mode(len(encoder_loader), len(linear_head_loader))
dataloaders = {
SSLDataModuleType.ENCODER: encoder_loader,
SSLDataModuleType.LINEAR_HEAD: linear_head_loader
}
return CombinedLoader(dataloaders, mode=mode)

def train_dataloader(self, *args: Any, **kwargs: Any) -> CombinedLoader: # type: ignore
def train_dataloader(self, *args: Any, **kwargs: Any) -> Dict[SSLDataModuleType, DataLoader]: # type: ignore
"""
The train dataloaders
"""
return self.get_combined_loader(encoder_loader=self.encoder_module.train_dataloader(),
linear_head_loader=self.linear_head_module.train_dataloader())
return {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: I would add also here a comment about the bug to explain why we return a dict instead of using the combined_datalaoder

SSLDataModuleType.ENCODER: self.encoder_module.train_dataloader(),
SSLDataModuleType.LINEAR_HEAD: self.linear_head_module.train_dataloader()
}

def val_dataloader(self, *args: Any, **kwargs: Any) -> CombinedLoader: # type: ignore
"""
Expand Down
7 changes: 4 additions & 3 deletions InnerEye/ML/SSL/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ def get_encoder_output_dim(
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import (
SSLOnlineEvaluatorInnerEye,
)

batch = next(iter(dm.train_dataloader()))
batch = batch[SSLDataModuleType.LINEAR_HEAD] if isinstance(batch, dict) else batch # type: ignore
loaders = dm.train_dataloader()
loader = loaders[SSLDataModuleType.LINEAR_HEAD] if isinstance(loaders, dict) else loaders # type: ignore
iterator = iter(loader)
batch = next(iterator)
x, _ = SSLOnlineEvaluatorInnerEye.to_device(batch, device)
else:
x = torch.rand((1, 3, 256, 256)).to(device)
Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/lightning_containers/ssl_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class SSLContainer(LightningContainer):
use_balanced_binary_loss_for_linear_head = param.Boolean(default=False,
doc="Whether to use a balanced loss for the training of "
"the linear head")
num_workers = param.Integer(default=6, doc="Number of workers to use for dataloader processes.")
num_workers = param.Integer(default=4, doc="Number of workers to use for dataloader processes.")
is_debug_model = param.Boolean(default=False,
doc="If True, the training will be restricted to 1 batch per epoch."
"Used for debugging and tests.")
Expand Down
16 changes: 14 additions & 2 deletions InnerEye/ML/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, change_working_directory
from InnerEye.Common.resource_monitor import ResourceMonitor
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule
from InnerEye.ML.common import ARGS_TXT, AUTOSAVE_CHECKPOINT_FILE_NAME, ModelExecutionMode, \
VISUALIZATION_FOLDER
from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning
Expand Down Expand Up @@ -57,7 +58,8 @@ def write_args_file(config: Any, outputs_folder: Path) -> None:

def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1) -> \
num_nodes: int = 1,
multiple_trainloader_mode: str = "max_size_cycle") -> \
Tuple[Trainer, StoringLogger]:
"""
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
Expand Down Expand Up @@ -174,6 +176,7 @@ def create_lightning_trainer(container: LightningContainer,
detect_anomaly=container.detect_anomaly,
profiler=container.pl_profiler,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
multiple_trainloader_mode=multiple_trainloader_mode,
**additional_args)
return trainer, storing_logger

Expand Down Expand Up @@ -230,6 +233,14 @@ def model_train(checkpoint_path: Optional[Path],
container.before_training_on_local_rank_zero()
container.before_training_on_all_ranks()

# Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument
# because training data that uses a CombinedLoader is not split correctly in DDP
multiple_trainloader_mode = "max_size_cycle"
if isinstance(data_module, CombinedDataModule):
data_module.prepare_data()
assert data_module.train_loader_cycle_mode is not None, "This field should be computed during prepare_data"
multiple_trainloader_mode = data_module.train_loader_cycle_mode

# Create the trainer object. Backup the environment variables before doing that, in case we need to run a second
# training in the unit tests.d
old_environ = dict(os.environ)
Expand All @@ -238,7 +249,8 @@ def model_train(checkpoint_path: Optional[Path],
seed_everything(container.get_effective_random_seed())
trainer, storing_logger = create_lightning_trainer(container,
checkpoint_path,
num_nodes=num_nodes)
num_nodes=num_nodes,
multiple_trainloader_mode=multiple_trainloader_mode)
rank_info = ", ".join(f"{env}: {os.getenv(env)}"
for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
logging.info(f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}")
Expand Down
3 changes: 0 additions & 3 deletions Tests/SSL/test_data_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,6 @@ def test_combined_data_module() -> None:
assert torch.isclose(combined_loader.class_weights,
torch.tensor([0.21, 0.79], dtype=torch.float32), atol=1e-3).all()

train_dataloader = combined_loader.train_dataloader()
assert isinstance(train_dataloader, CombinedLoader)

indices_classifier_module_short = []
val_dataloader = combined_loader.val_dataloader()
assert isinstance(val_dataloader, CombinedLoader)
Expand Down
12 changes: 8 additions & 4 deletions Tests/SSL/test_ssl_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
# Check the metrics that were recorded during training
# Note: It is possible that after the PyTorch 1.10 upgrade, we can't get parity between local runs and runs on
# the hosted build agents. If that suspicion is confirmed, we need to add branching for local and cloud results.
expected_metrics = {'simclr/val/loss': 2.8797268867492676,
'ssl_online_evaluator/val/loss': 2.272602081298828,
expected_metrics = {'simclr/val/loss': 2.8736934661865234,
'ssl_online_evaluator/val/loss': 2.2684895992279053,
'ssl_online_evaluator/val/AccuracyAtThreshold05': 0.20000000298023224,
'simclr/train/loss': 3.6261773109436035,
'simclr/learning_rate': 0.0,
Expand Down Expand Up @@ -615,8 +615,12 @@ def test_simclr_dataset_length(test_output_dirs: OutputFolderForTests,
model = container.create_model()
expected_num_train_iters = (num_encoder_images * 0.9) // encoder_batch_size
assert model.train_iters_per_epoch == expected_num_train_iters
train_loaders = container.get_data_module().train_dataloader()
assert isinstance(train_loaders, CombinedLoader)
data_module = container.get_data_module()
data_module.prepare_data()
train_loaders_dict = data_module.train_dataloader()
assert isinstance(train_loaders_dict, dict)
assert data_module.train_loader_cycle_mode
train_loaders = CombinedLoader(train_loaders_dict, mode=data_module.train_loader_cycle_mode)
assert len(train_loaders) == expected_num_train_iters
expected_num_val_iters = (num_encoder_images * 0.1) // encoder_batch_size
val_loaders = container.get_data_module().val_dataloader()
Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines/build-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:

- job: Linux
pool:
vmImage: 'ubuntu-18.04'
vmImage: 'ubuntu-20.04'
steps:
- template: build.yaml

Expand Down