diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 7c59b670b7..0433617649 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -74,8 +74,8 @@ def create_multigpu_supervised_trainer( tuple of tensors `(batch_x, batch_y)`. output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Returns: Engine: a trainer engine with supervised update function. @@ -87,6 +87,8 @@ def create_multigpu_supervised_trainer( devices_ = get_devices_spec(devices) if distributed: + if len(devices_) > 1: + raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.") net = DistributedDataParallel(net, device_ids=devices_) elif len(devices_) > 1: net = DataParallel(net) @@ -122,8 +124,8 @@ def create_multigpu_supervised_evaluator( output_transform: function that receives 'x', 'y', 'y_pred' and returns value to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits output expected by metrics. If you change it you should use `output_transform` in metrics. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. 
Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is @@ -137,6 +139,10 @@ def create_multigpu_supervised_evaluator( if distributed: net = DistributedDataParallel(net, device_ids=devices_) + if len(devices_) > 1: + raise ValueError( + f"for distributed evaluation, `devices` must contain only 1 GPU or CPU, but got {devices_}." + ) elif len(devices_) > 1: net = DataParallel(net) diff --git a/tests/min_tests.py b/tests/min_tests.py index 783ab370c1..090167c4b1 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -155,6 +155,7 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_prepare_batch_default_dist", + "test_parallel_execution_dist", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_parallel_execution_dist.py b/tests/test_parallel_execution_dist.py new file mode 100644 index 0000000000..f067b71d14 --- /dev/null +++ b/tests/test_parallel_execution_dist.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import unittest

import torch
import torch.distributed as dist

from monai.engines import create_multigpu_supervised_trainer
from tests.utils import DistCall, DistTestCase, skip_if_no_cuda


def fake_loss(y_pred, y):
    """Return a scalar stand-in loss: the sum of ``y_pred[0] + y``."""
    return (y_pred[0] + y).sum()


def fake_data_stream():
    """Yield an endless stream of random ``(input, target)`` batches.

    Both tensors in every yielded pair have shape ``(10, 1, 64, 64)``.
    """
    while True:
        yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))


class DistributedTestParallelExecution(DistTestCase):
    """Check `create_multigpu_supervised_trainer` with ``distributed=True``."""

    @DistCall(nnodes=1, nproc_per_node=2)
    @skip_if_no_cuda
    def test_distributed(self):
        """Run a short training with one CUDA device per distributed rank."""
        # each rank trains on its own GPU, so `devices` holds exactly one entry
        device = torch.device(f"cuda:{dist.get_rank()}")
        net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device)
        opt = torch.optim.Adam(net.parameters(), 1e-3)

        trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device], distributed=True)
        trainer.run(fake_data_stream(), max_epochs=2, epoch_length=2)
        # assert the trainer output is loss value
        self.assertIsInstance(trainer.state.output, float)


if __name__ == "__main__":
    unittest.main()