Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ created.
## Upcoming

### Added
- ([#667](https://github.com/microsoft/InnerEye-DeepLearning/pull/667)) Automatically and linearly scale the learning rate of the SSL encoder to the number of GPUs.
Comment thread
maxilse marked this conversation as resolved.
- ([#689](https://github.com/microsoft/InnerEye-DeepLearning/pull/689)) Show default argument values in help message.
- ([#671](https://github.com/microsoft/InnerEye-DeepLearning/pull/671)) Remove sequence models and unused variables. Simplify README.
- ([#693](https://github.com/microsoft/InnerEye-DeepLearning/pull/693)) Improve instructions for HelloWorld model in AzureML.
Expand Down
12 changes: 11 additions & 1 deletion InnerEye/ML/SSL/lightning_containers/ssl_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,23 @@ def create_model(self) -> LightningModule:
# For small images like CIFAR, if using a resnet encoder, switch the first conv layer to a 3x3 kernel instead
# of a 7x7 conv layer.
use_7x7_first_conv_in_resnet = False if self.ssl_training_dataset_name.value.startswith("CIFAR") else True

# Rescale the learning rate linearly according to the number of available GPUs, as seen in: https://arxiv.org/abs/1706.02677,
# to avoid a drop in performance
gpus_per_node = self.num_gpus_per_node()
Comment thread
maxilse marked this conversation as resolved.
num_of_total_gpus = self.num_nodes * gpus_per_node
if num_of_total_gpus > 1:
l_rate: float = self.l_rate * num_of_total_gpus
logging.info(f"We found {num_of_total_gpus} GPUs, SSL encoder learning rate has been adjusted from {self.l_rate} to {l_rate}")
self.l_rate = l_rate

if self.ssl_training_type == SSLTrainingType.SimCLR:
model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
dataset_name=self.ssl_training_dataset_name.value,
use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
num_samples=self.data_module.num_train_samples,
batch_size=self.data_module.batch_size,
gpus=self.num_gpus_per_node(),
gpus=gpus_per_node,
num_nodes=self.num_nodes,
learning_rate=self.l_rate,
max_epochs=self.num_epochs)
Expand Down