Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions monai/engines/multi_gpu_supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def create_multigpu_supervised_trainer(
tuple of tensors `(batch_x, batch_y)`.
output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use
the first device as output device.
distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain
only 1 GPU or CPU for current distributed rank.

Returns:
Engine: a trainer engine with supervised update function.
Expand All @@ -87,6 +87,8 @@ def create_multigpu_supervised_trainer(

devices_ = get_devices_spec(devices)
if distributed:
if len(devices_) > 1:
raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.")
net = DistributedDataParallel(net, device_ids=devices_)
elif len(devices_) > 1:
net = DataParallel(net)
Expand Down Expand Up @@ -122,8 +124,8 @@ def create_multigpu_supervised_evaluator(
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)`
which fits output expected by metrics. If you change it you should use `output_transform` in metrics.
distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use
the first device as output device.
distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain
only 1 GPU or CPU for current distributed rank.

Note:
`engine.state.output` for this engine is defined by `output_transform` parameter and is
Expand All @@ -137,6 +139,10 @@ def create_multigpu_supervised_evaluator(

if distributed:
net = DistributedDataParallel(net, device_ids=devices_)
if len(devices_) > 1:
raise ValueError(
f"for distributed evaluation, `devices` must contain only 1 GPU or CPU, but got {devices_}."
)
elif len(devices_) > 1:
net = DataParallel(net)

Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def run_testsuit():
"test_zoom_affine",
"test_zoomd",
"test_prepare_batch_default_dist",
"test_parallel_execution_dist",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
45 changes: 45 additions & 0 deletions tests/test_parallel_execution_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
import torch.distributed as dist

from monai.engines import create_multigpu_supervised_trainer
from tests.utils import DistCall, DistTestCase, skip_if_no_cuda


def fake_loss(y_pred, y):
return (y_pred[0] + y).sum()


def fake_data_stream():
while True:
yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))


class DistributedTestParallelExecution(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2)
@skip_if_no_cuda
def test_distributed(self):
device = torch.device(f"cuda:{dist.get_rank()}")
net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device)
opt = torch.optim.Adam(net.parameters(), 1e-3)

trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device], distributed=True)
trainer.run(fake_data_stream(), 2, 2)
Comment thread
Nic-Ma marked this conversation as resolved.
# assert the trainer output is loss value
self.assertTrue(isinstance(trainer.state.output, float))


if __name__ == "__main__":
unittest.main()