diff --git a/CHANGELOG.md b/CHANGELOG.md index dc41842f2..92f6631b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ multiple large checkpoints can time out. ### Added +- ([#483](https://github.com/microsoft/InnerEye-DeepLearning/pull/483)) Allow cross validation with 'bring your own' Lightning models (without ensemble building). - ([#489](https://github.com/microsoft/InnerEye-DeepLearning/pull/489)) Remove portal query for outliers. - ([#488](https://github.com/microsoft/InnerEye-DeepLearning/pull/488)) Better handling of missing seriesId in segmentation cross validation reports. - ([#454](https://github.com/microsoft/InnerEye-DeepLearning/pull/454)) Checking that labels are mutually exclusive. diff --git a/InnerEye/ML/configs/other/HelloContainer.py b/InnerEye/ML/configs/other/HelloContainer.py index b911ccad7..6b6d5e9cb 100644 --- a/InnerEye/ML/configs/other/HelloContainer.py +++ b/InnerEye/ML/configs/other/HelloContainer.py @@ -3,7 +3,7 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
# ------------------------------------------------------------------------------------------ from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch @@ -12,6 +12,7 @@ from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import StepLR, _LRScheduler from torch.utils.data import DataLoader, Dataset +from sklearn.model_selection import KFold from InnerEye.Common import fixed_paths from InnerEye.ML.lightning_container import LightningContainer @@ -31,15 +32,14 @@ class HelloDataset(Dataset): # y = 0.2 * x + 0.1 * torch.randn(x.size()) # xy = torch.cat((x, y), dim=1) # np.savetxt("InnerEye/ML/configs/other/hellocontainer.csv", xy.numpy(), delimiter=",") - def __init__(self, root_folder: Path, start_index: int, end_index: int) -> None: + def __init__(self, raw_data: List[List[float]]) -> None: """ Creates the 1-dim regression dataset. - :param root_folder: The folder in which the data file lives ("hellocontainer.csv") - :param start_index: The first row to read. - :param end_index: The last row to read (exclusive) + :param raw_data: The raw data, e.g. from a cross validation split or loaded from file. This + must be numeric data which can be converted into a tensor. See the static method + from_path_and_indexes for an example call. """ - super().__init__() - raw_data = np.loadtxt(str(root_folder / "hellocontainer.csv"), delimiter=",")[start_index:end_index] + super().__init__() self.data = torch.tensor(raw_data, dtype=torch.float) def __len__(self) -> int: @@ -48,17 +48,67 @@ def __len__(self) -> int: def __getitem__(self, item: int) -> Dict[str, torch.Tensor]: return {'x': self.data[item][0:1], 'y': self.data[item][1:2]} + @staticmethod + def from_path_and_indexes( + root_folder: Path, + start_index: int, + end_index: int) -> 'HelloDataset': + ''' + Static method to instantiate a HelloDataset from the root folder with the start and end indexes. 
+ :param root_folder: The folder in which the data file lives ("hellocontainer.csv") + :param start_index: The first row to read. + :param end_index: The last row to read (exclusive) + :return: A new instance based on the root folder and the start and end indexes. + ''' + raw_data = np.loadtxt(root_folder / "hellocontainer.csv", delimiter=",")[start_index:end_index] + return HelloDataset(raw_data) + class HelloDataModule(LightningDataModule): """ A data module that gives the training, validation and test data for a simple 1-dim regression task. + If not using cross validation a basic 50% / 20% / 30% split between train, validation, and test data + is made on the whole dataset. + For cross validation (if required) we use k-fold cross-validation. The test set remains unchanged + while the training and validation data cycle through the k-folds of the remaining data. """ - - def __init__(self, root_folder: Path) -> None: + def __init__( + self, + root_folder: Path, + number_of_cross_validation_splits: int = 0, + cross_validation_split_index: int = 0) -> None: super().__init__() - self.train = HelloDataset(root_folder, start_index=0, end_index=50) - self.val = HelloDataset(root_folder, start_index=50, end_index=70) - self.test = HelloDataset(root_folder, start_index=70, end_index=100) + if number_of_cross_validation_splits <= 1: + # For 0 or 1 splits just use the default values on the whole data-set. 
+ self.train = HelloDataset.from_path_and_indexes(root_folder, start_index=0, end_index=50) + self.val = HelloDataset.from_path_and_indexes(root_folder, start_index=50, end_index=70) + self.test = HelloDataset.from_path_and_indexes(root_folder, start_index=70, end_index=100) + else: + # Raise exceptions for unreasonable values + if cross_validation_split_index >= number_of_cross_validation_splits: + raise IndexError(f"The cross_validation_split_index ({cross_validation_split_index}) is too large " + f"given the number_of_cross_validation_splits ({number_of_cross_validation_splits}) requested") + raw_data = np.loadtxt(root_folder / "hellocontainer.csv", delimiter=",") + np.random.seed(42) + np.random.shuffle(raw_data) + if number_of_cross_validation_splits >= len(raw_data): + raise ValueError(f"Asked for {number_of_cross_validation_splits} cross validation splits from a " + f"dataset of length {len(raw_data)}") + # Hold out the last 30% as test data + self.test = HelloDataset(raw_data[70:100]) + # Create k-folds from the remaining 70% of the data-set. Use one for the validation + # data and the rest for the training data + raw_data_remaining = raw_data[0:70] + k_fold = KFold(n_splits=number_of_cross_validation_splits) + train_indexes, val_indexes = list(k_fold.split(raw_data_remaining))[cross_validation_split_index] + self.train = HelloDataset(raw_data_remaining[train_indexes]) + self.val = HelloDataset(raw_data_remaining[val_indexes]) + + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + pass + + def setup(self, stage: Optional[str] = None) -> None: + pass def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: return DataLoader(self.train, batch_size=5) @@ -135,7 +185,7 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: This method is part of the standard PyTorch Lightning interface. 
For an introduction, please see https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html It returns the PyTorch optimizer(s) and learning rate scheduler(s) that should be used for training. -= """ + """ optimizer = Adam(self.parameters(), lr=1e-1) scheduler = StepLR(optimizer, step_size=20, gamma=0.5) return [optimizer], [scheduler] @@ -203,10 +253,19 @@ def create_model(self) -> LightningModule: return HelloRegression() # This method must be overridden by any subclass of LightningContainer. It returns a data module, which - # in turn contains 3 data loaders for training, validation, and test set. + # in turn contains 3 data loaders for training, validation, and test set. + # + # If the container is used for cross validation then this method must handle the cross validation splits. + # Because this deals with data loaders, not loaded data, we cannot check automatically that cross validation is + # handled correctly within the LightningContainer base class, i.e. if you forget to do the cross validation split + # in your subclass nothing will fail, but each child run will be identical since they will each be given the full + # dataset. def get_data_module(self) -> LightningDataModule: assert self.local_dataset is not None - return HelloDataModule(root_folder=self.local_dataset) # type: ignore + return HelloDataModule( + root_folder=self.local_dataset, + number_of_cross_validation_splits=self.number_of_cross_validation_splits, + cross_validation_split_index=self.cross_validation_split_index) # type: ignore # This is an optional override: This report creation method can read out any files that were written during # training, and cook them into a nice looking report. Here, the report is a simple text file. 
diff --git a/InnerEye/ML/lightning_container.py b/InnerEye/ML/lightning_container.py index f2f081a45..bfc1b4e09 100644 --- a/InnerEye/ML/lightning_container.py +++ b/InnerEye/ML/lightning_container.py @@ -11,8 +11,12 @@ from pytorch_lightning import LightningDataModule, LightningModule from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler +from azureml.core import ScriptRunConfig +from azureml.train.hyperdrive import GridParameterSampling, HyperDriveConfig, PrimaryMetricGoal, choice +from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY from InnerEye.Common.generic_parsing import GenericConfig, create_from_matching_params +from InnerEye.Common.metrics_constants import TrackedMetrics from InnerEye.ML.common import ModelExecutionMode from InnerEye.ML.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \ WorkflowParams, load_checkpoint @@ -175,6 +179,9 @@ def get_data_module(self) -> LightningDataModule: The format of the data is not specified any further. The method must take cross validation into account, and ensure that logic to create training and validation sets takes cross validation with a given number of splits is correctly taken care of. + Because the method deals with data loaders, not loaded data, we cannot check automatically that cross validation + is handled correctly within the base class, i.e. if the cross validation split is not handled in the method then + nothing will fail, but each child run will be identical since they will each be given the full dataset. :return: A LightningDataModule """ return None # type: ignore @@ -200,6 +207,12 @@ def get_trainer_arguments(self) -> Dict[str, Any]: """ return dict() + def get_parameter_search_hyperdrive_config(self, _: ScriptRunConfig) -> HyperDriveConfig: # type: ignore + """ + Parameter search is not implemented. It should be implemented in a sub class if needed. 
+ """ + raise NotImplementedError("Parameter search is not implemented. It should be implemented in a sub class if needed.") + def create_report(self) -> None: """ This method is called after training and testing has been completed. It can aggregate all files that were @@ -288,6 +301,38 @@ def create_lightning_module_and_store(self) -> None: self._model._optimizer_params = create_from_matching_params(self, OptimizerParams) self._model._trainer_params = create_from_matching_params(self, TrainerParams) + def get_cross_validation_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig: + """ + Returns a configuration for AzureML Hyperdrive that varies the cross validation split index. + Because this adds a val/Loss metric it is important that when subclassing LightningContainer + your implementation of LightningModule logs val/Loss. There is an example of this in + HelloRegression's validation_step method. + :param run_config: The AzureML run configuration object for training an individual model. + :return: A hyperdrive configuration object. + """ + return HyperDriveConfig( + run_config=run_config, + hyperparameter_sampling=GridParameterSampling( + parameter_space={ + CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY: choice(list(range(self.number_of_cross_validation_splits))) + }), + primary_metric_name=TrackedMetrics.Val_Loss.value, + primary_metric_goal=PrimaryMetricGoal.MINIMIZE, + max_total_runs=self.number_of_cross_validation_splits + ) + + def get_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig: + """ + Returns the HyperDrive config for either parameter search or cross validation + (if number_of_cross_validation_splits > 1). 
+ :param run_config: AzureML estimator + :return: HyperDriveConfigs + """ + if self.perform_cross_validation: + return self.get_cross_validation_hyperdrive_config(run_config) + else: + return self.get_parameter_search_hyperdrive_config(run_config) + def __str__(self) -> str: """Returns a string describing the present object, as a list of key: value strings.""" arguments_str = "\nContainer:\n" diff --git a/InnerEye/ML/model_config_base.py b/InnerEye/ML/model_config_base.py index 0454233e4..edd849e7f 100644 --- a/InnerEye/ML/model_config_base.py +++ b/InnerEye/ML/model_config_base.py @@ -150,13 +150,6 @@ def create_model(self) -> Any: # because this would prevent us from easily instantiating this class in tests. raise NotImplementedError("create_model must be overridden") - def get_total_number_of_cross_validation_runs(self) -> int: - """ - Returns the total number of HyperDrive/offline runs required to sample the entire - cross validation parameter space. - """ - return self.number_of_cross_validation_splits - def get_cross_validation_hyperdrive_sampler(self) -> GridParameterSampling: """ Returns the cross validation sampler, required to sample the entire parameter space for cross validation. 
@@ -176,7 +169,7 @@ def get_cross_validation_hyperdrive_config(self, run_config: ScriptRunConfig) -> hyperparameter_sampling=self.get_cross_validation_hyperdrive_sampler(), primary_metric_name=TrackedMetrics.Val_Loss.value, primary_metric_goal=PrimaryMetricGoal.MINIMIZE, - max_total_runs=self.get_total_number_of_cross_validation_runs() + max_total_runs=self.number_of_cross_validation_splits ) def get_cross_validation_dataset_splits(self, dataset_split: DatasetSplits) -> DatasetSplits: diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index a46473285..b6f018f13 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -802,7 +802,7 @@ def are_sibling_runs_finished(self) -> bool: """ if (not self.is_offline_run) \ and (azure_util.is_cross_validation_child_run(RUN_CONTEXT)): - n_splits = self.innereye_config.get_total_number_of_cross_validation_runs() + n_splits = self.innereye_config.number_of_cross_validation_splits child_runs = azure_util.fetch_child_runs(PARENT_RUN_CONTEXT, expected_number_cross_validation_splits=n_splits) pending_runs = [x.id for x in child_runs diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py index 0cac30dd8..9ba94d2d4 100755 --- a/InnerEye/ML/runner.py +++ b/InnerEye/ML/runner.py @@ -192,8 +192,6 @@ def run(self) -> Tuple[Optional[DeepLearningConfig], Optional[Run]]: user_agent.append(azure_util.INNEREYE_SDK_NAME, azure_util.INNEREYE_SDK_VERSION) self.parse_and_load_model() if self.lightning_container.perform_cross_validation: - if self.model_config is None: - raise NotImplementedError("Cross validation for LightingContainer models is not yet supported.") # force hyperdrive usage if performing cross validation self.azure_config.hyperdrive = True run_object: Optional[Run] = None @@ -219,14 +217,14 @@ def submit_to_azureml(self) -> Run: if isinstance(self.model_config, DeepLearningConfig) and not self.lightning_container.azure_dataset_id: raise ValueError("When running an InnerEye built-in model in AzureML, the 
'azure_dataset_id' " "property must be set.") - hyperdrive_func = lambda run_config: self.model_config.get_hyperdrive_config(run_config) # type: ignore source_config = SourceConfig( root_folder=self.project_root, entry_script=Path(sys.argv[0]).resolve(), conda_dependencies_files=get_all_environment_files(self.project_root), - hyperdrive_config_func=hyperdrive_func, + hyperdrive_config_func=(self.model_config.get_hyperdrive_config if self.model_config + else self.lightning_container.get_hyperdrive_config), # For large jobs, upload of results can time out because of large checkpoint files. Default is 600 - upload_timeout_seconds=86400, + upload_timeout_seconds=86400 ) source_config.set_script_params_except_submit_flag() # Reduce the size of the snapshot by adding unused folders to amlignore. The Test* subfolders are only needed diff --git a/InnerEye/ML/visualizers/plot_cross_validation.py b/InnerEye/ML/visualizers/plot_cross_validation.py index 143d05925..baf09eb7d 100644 --- a/InnerEye/ML/visualizers/plot_cross_validation.py +++ b/InnerEye/ML/visualizers/plot_cross_validation.py @@ -440,7 +440,7 @@ def crossval_config_from_model_config(train_config: DeepLearningConfig) -> PlotC model_category=train_config.model_category, epoch=epoch, should_validate=False, - number_of_cross_validation_splits=train_config.get_total_number_of_cross_validation_runs()) + number_of_cross_validation_splits=train_config.number_of_cross_validation_splits) def get_config_and_results_for_offline_runs(train_config: DeepLearningConfig) -> OfflineCrossvalConfigAndFiles: diff --git a/Tests/ML/configs/other/test_hello_container.py b/Tests/ML/configs/other/test_hello_container.py new file mode 100644 index 000000000..c32b302ea --- /dev/null +++ b/Tests/ML/configs/other/test_hello_container.py @@ -0,0 +1,86 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +import torch +import pytest +from typing import Optional +from InnerEye.ML.configs.other.HelloContainer import HelloContainer, HelloDataModule + + +def test_cross_validation_splits() -> None: + ''' + Test that the exemplar HelloDataModule correctly splits the data set for cross validation. + + We test that unreasonable values raise an exception (i.e. a split index greater than the number of splits, + or requesting a number of splits larger than the available data). + - We test that between splits the train and validation data differ, and the test data remains the same. + - We test that all the data is used in each split. + - We test that across all the splits the validation data use all of the non-test data. + ''' + # Get full data-set for comparison + root_folder = HelloContainer().local_dataset + assert root_folder is not None + data_module_no_xval = HelloDataModule(root_folder=root_folder) + + for number_of_cross_validation_splits in [0, 1, 5]: + previous_data_module_xval: Optional[HelloDataModule] = None + for cross_validation_split_index in range(number_of_cross_validation_splits): + data_module_xval = HelloDataModule( + root_folder=root_folder, + cross_validation_split_index=cross_validation_split_index, + number_of_cross_validation_splits=number_of_cross_validation_splits) + _assert_total_data_identical( + data_module_no_xval, + data_module_xval, + f"Total data mismatch for cross validation split ({number_of_cross_validation_splits}, {cross_validation_split_index})") + if number_of_cross_validation_splits <= 1: + break + if previous_data_module_xval: + _check_train_val_test(previous_data_module_xval, data_module_xval) + previous_data_module_xval = data_module_xval + if cross_validation_split_index == 0: + accrued_val_data = data_module_xval.val.data + else: + 
accrued_val_data = torch.cat((accrued_val_data, data_module_xval.val.data), dim=0) + if cross_validation_split_index == number_of_cross_validation_splits - 1: + all_non_test_data = torch.cat((data_module_xval.train.data, data_module_xval.val.data), dim=0) + msg = "Accrued validation sets from all the cross validations does not match the total non-test data" + assert torch.equal(torch.sort(all_non_test_data, dim=0)[0], torch.sort(accrued_val_data, dim=0)[0]), msg + + with pytest.raises(IndexError) as index_error: + data_module_xval = HelloDataModule( + root_folder=root_folder, + cross_validation_split_index=6, + number_of_cross_validation_splits=5) + assert "too large given the number_of_cross_validation_splits" in str(index_error.value) + + with pytest.raises(ValueError) as value_error: + data_module_xval = HelloDataModule( + root_folder=root_folder, + cross_validation_split_index=0, + number_of_cross_validation_splits=10_000) + assert "Asked for 10000 cross validation splits from a dataset of length" in str(value_error.value) + + +def _assert_total_data_identical(dm1: HelloDataModule, dm2: HelloDataModule, msg: str) -> None: + ''' + Check that the total of the two HelloDataModule's train, val, and test data is identical + ''' + all_data1 = torch.cat((dm1.train.data, dm1.val.data, dm1.test.data), dim=0) + all_data2 = torch.cat((dm2.train.data, dm2.val.data, dm2.test.data), dim=0) + all_data1_sorted, _ = torch.sort(all_data1, dim=0) + all_data2_sorted, _ = torch.sort(all_data2, dim=0) + assert torch.equal(all_data1_sorted, all_data2_sorted), msg + + +def _check_train_val_test(dm1: HelloDataModule, dm2: HelloDataModule) -> None: + ''' + Check that the two HelloDataModule's train and val data is different, but that their test data is identical + ''' + msg = "Two cross validation sets have the same training data" + assert not torch.equal(torch.sort(dm1.train.data, dim=0)[0], torch.sort(dm2.train.data, dim=0)[0]), msg + msg = "Two cross validation sets have the same 
validation data" + assert not torch.equal(torch.sort(dm1.val.data, dim=0)[0], torch.sort(dm2.val.data, dim=0)[0]), msg + msg = "Two cross validation sets have differing test data" + assert torch.equal(torch.sort(dm1.test.data, dim=0)[0], torch.sort(dm2.test.data, dim=0)[0]), msg diff --git a/Tests/ML/models/test_scalar_model.py b/Tests/ML/models/test_scalar_model.py index 6194aa7ad..12201ac9a 100644 --- a/Tests/ML/models/test_scalar_model.py +++ b/Tests/ML/models/test_scalar_model.py @@ -293,7 +293,7 @@ def test_run_ml_with_classification_model(test_output_dirs: OutputFolderForTests config_and_files = get_config_and_results_for_offline_runs(config) result_files = config_and_files.files # One file for VAL, one for TRAIN and one for TEST for each child run - assert len(result_files) == config.get_total_number_of_cross_validation_runs() * 3 + assert len(result_files) == config.number_of_cross_validation_splits * 3 for file in result_files: assert file.dataset_csv_file is not None assert file.dataset_csv_file.exists() @@ -526,7 +526,7 @@ def test_is_offline_cross_val_parent_run(offline_parent_cv_run: bool) -> None: def _check_offline_cross_validation_output_files(train_config: ScalarModelBase) -> None: metrics: Dict[ModelExecutionMode, List[pd.DataFrame]] = dict() root = Path(train_config.file_system_config.outputs_folder) - for x in range(train_config.get_total_number_of_cross_validation_runs()): + for x in range(train_config.number_of_cross_validation_splits): expected_outputs_folder = root / str(x) assert expected_outputs_folder.exists() for m in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL, ModelExecutionMode.TEST]: diff --git a/Tests/ML/runners/test_runner.py b/Tests/ML/runners/test_runner.py index 186c3ea27..a0d8eede9 100644 --- a/Tests/ML/runners/test_runner.py +++ b/Tests/ML/runners/test_runner.py @@ -4,15 +4,18 @@ # ------------------------------------------------------------------------------------------ import logging import time +from unittest import 
mock import pytest +from azureml.train.hyperdrive.runconfig import HyperDriveConfig -from InnerEye.Common import common_util +from InnerEye.Common import common_util, fixed_paths from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path from InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, ModelExecutionMode from InnerEye.ML.metrics import InferenceMetricsForSegmentation from InnerEye.ML.run_ml import MLRunner +from InnerEye.ML.runner import Runner from Tests.ML.configs.DummyModel import DummyModel from Tests.ML.util import get_default_checkpoint_handler from Tests.ML.utils.test_model_util import create_model_and_store_checkpoint @@ -71,3 +74,22 @@ def test_logging_to_file(test_output_dirs: OutputFolderForTests) -> None: assert file_path.exists() assert log_line in file_path.read_text() assert should_not_be_present not in file_path.read_text() + + +def test_cross_validation_for_LightingContainer_models_is_supported() -> None: + ''' + Prior to https://github.com/microsoft/InnerEye-DeepLearning/pull/483 we raised an exception in + runner.run when cross validation was attempted on a lightning container. This test checks that + we do not raise the exception anymore, and instead pass on a cross validation hyperdrive config + to azure_runner's submit_to_azureml method. 
+ ''' + args_list = ["--model=HelloContainer", "--number_of_cross_validation_splits=5", "--azureml=True"] + with mock.patch("sys.argv", [""] + args_list): + runner = Runner(project_root=fixed_paths.repository_root_directory(), yaml_config_file=fixed_paths.SETTINGS_YAML_FILE) + with mock.patch("InnerEye.Azure.azure_runner.create_and_submit_experiment", return_value=None) as create_and_submit_experiment_patch: + runner.run() + assert runner.lightning_container.model_name == 'HelloContainer' + assert runner.lightning_container.number_of_cross_validation_splits == 5 + args, _ = create_and_submit_experiment_patch.call_args + script_run_config = args[1] + assert isinstance(script_run_config, HyperDriveConfig) diff --git a/Tests/ML/test_lightning_containers.py b/Tests/ML/test_lightning_containers.py index 51a852d74..c8c53a3b3 100644 --- a/Tests/ML/test_lightning_containers.py +++ b/Tests/ML/test_lightning_containers.py @@ -10,6 +10,8 @@ import pandas as pd import pytest from pytorch_lightning import LightningModule +from azureml.core import ScriptRunConfig +from azureml.train.hyperdrive.runconfig import HyperDriveConfig from InnerEye.Common.output_directories import OutputFolderForTests from InnerEye.ML.common import ModelExecutionMode @@ -19,7 +21,7 @@ from InnerEye.ML.model_config_base import ModelConfigBase from InnerEye.ML.run_ml import MLRunner from Tests.ML.configs.DummyModel import DummyModel -from Tests.ML.configs.lightning_test_containers import DummyContainerWithHooks, DummyContainerWithModel, \ +from Tests.ML.configs.lightning_test_containers import DummyContainerWithAzureDataset, DummyContainerWithHooks, DummyContainerWithModel, \ DummyContainerWithPlainLightning from Tests.ML.util import default_runner @@ -280,3 +282,33 @@ def test_container_hooks(test_output_dirs: OutputFolderForTests) -> None: # only check that they have all been called. 
for file in ["global_rank_zero.txt", "local_rank_zero.txt", "all_ranks.txt"]: assert (runner.container.outputs_folder / file).is_file(), f"Missing file: {file}" + +@pytest.mark.parametrize("number_of_cross_validation_splits", [0, 2]) +def test_get_hyperdrive_config(number_of_cross_validation_splits: int, + test_output_dirs: OutputFolderForTests) -> None: + """ + Testing that the hyperdrive config returned for the lightning container is right for submitting + to AzureML. + + Note that because the function get_hyperdrive_config now lives in the super class WorkflowParams, + it is also tested for other aspects of functionality by a test of the same name in + Tests.ML.test_model_config_base. + """ + container = DummyContainerWithAzureDataset() + container.number_of_cross_validation_splits = number_of_cross_validation_splits + run_config = ScriptRunConfig( + source_directory=str(test_output_dirs.root_dir), + script=str(Path("something.py")), + arguments=["foo"], + compute_target="EnormousCluster") + if number_of_cross_validation_splits == 0: + with pytest.raises(NotImplementedError) as not_implemented_error: + container.get_hyperdrive_config(run_config=run_config) + assert 'Parameter search is not implemented' in str(not_implemented_error.value) + # The error should be thrown by + # InnerEye.ML.lightning_container.LightningContainer.get_parameter_search_hyperdrive_config + # since number_of_cross_validation_splits == 0 implies a parameter search hyperdrive config and + # not a cross validation one. 
+ else: + hd_config = container.get_hyperdrive_config(run_config=run_config) + assert isinstance(hd_config, HyperDriveConfig) diff --git a/Tests/ML/test_model_config_base.py b/Tests/ML/test_model_config_base.py index 8fdd19d64..172b3b30f 100644 --- a/Tests/ML/test_model_config_base.py +++ b/Tests/ML/test_model_config_base.py @@ -80,7 +80,6 @@ def test_get_total_number_of_cross_validation_runs() -> None: config = ModelConfigBase(should_validate=False) config.number_of_cross_validation_splits = 2 assert config.perform_cross_validation - assert config.get_total_number_of_cross_validation_runs() == config.number_of_cross_validation_splits def test_config_with_typo() -> None: diff --git a/docs/bring_your_own_model.md b/docs/bring_your_own_model.md index b8b56b26c..a2df9c50e 100644 --- a/docs/bring_your_own_model.md +++ b/docs/bring_your_own_model.md @@ -38,6 +38,16 @@ correctly. If you'd like to have your model defined in a different folder, pleas the `--model_configs_namespace` argument. For example, use `--model_configs_namespace=My.Own.configs` if your model configuration classes reside in folder `My/Own/configs` from the repository root. +### Cross Validation + +If you are doing cross validation you need to ensure that: +- The `LightningDataModule` returned by your container's `get_data_module` method takes into account the number +of cross validation splits, and the cross validation split index, when preparing the data. +- The `LightningModule` returned by your container's `create_model` method logs val/Loss in its `validation_step` method. + +You can find a working example of handling cross validation in the +[HelloContainer](../InnerEye/ML/configs/other/HelloContainer.py) class. + *Example*: ```python from pathlib import Path diff --git a/docs/building_models.md b/docs/building_models.md index 0ec3e7bb6..92a6a5d8f 100755 --- a/docs/building_models.md +++ b/docs/building_models.md @@ -127,9 +127,9 @@ at the same time (provided that the cluster has capacity). 
This means that a com takes as long as a single training run. 
To start cross validation, you can either modify the `number_of_cross_validation_splits` property of your model, -or supply it on the command line: Provide all the usual switches, and add `--number_of_cross_validation_splits=N`, +or supply it on the command line: provide all the usual switches, and add `--number_of_cross_validation_splits=N`, for some `N` greater than 1; a value of 5 is typical. This will start a -[HyperDrive run](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters): A parent +[HyperDrive run](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters): a parent AzureML job, with `N` child runs that will execute in parallel. You can see the child runs in the AzureML UI in the "Child Runs" tab.