diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 1223a10ef6..797b40b566 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -16,9 +16,14 @@ import torch from monai.data import CSVSaver -from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather from monai.utils import ImageMetaKey as Key -from monai.utils import exact_version, issequenceiterable, optional_import +from monai.utils import ( + evenly_divisible_all_gather, + exact_version, + issequenceiterable, + optional_import, + string_list_all_gather, +) idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -126,7 +131,7 @@ def _finalize(self, engine: Engine) -> None: outputs = torch.cat(self._outputs, dim=0) filenames = self._filenames if ws > 1: - outputs = evenly_divisible_all_gather(outputs) + outputs = evenly_divisible_all_gather(outputs, concat=True) filenames = string_list_all_gather(filenames) if len(filenames) == 0: diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 434dd483ed..57653cfc20 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -13,9 +13,8 @@ import torch -from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import do_metric_reduction -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction, evenly_divisible_all_gather, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") @@ -104,7 +103,7 @@ def compute(self) -> Any: ws = idist.get_world_size() if ws > 1 and not self._is_reduced: # all gather across all processes - _scores = evenly_divisible_all_gather(data=_scores) + _scores = 
evenly_divisible_all_gather(data=_scores, concat=True) self._is_reduced = True # save score of every image into engine.state for other components diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index a4a7de584f..15814fd3cf 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -11,9 +11,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union -from monai.handlers.utils import string_list_all_gather, write_metrics_reports +from monai.handlers.utils import write_metrics_reports from monai.utils import ImageMetaKey as Key -from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import +from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import, string_list_all_gather Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 8011dab8db..7a77f4473e 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -13,9 +13,8 @@ import torch -from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import compute_roc_auc -from monai.utils import Average, exact_version, optional_import +from monai.utils import Average, evenly_divisible_all_gather, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") EpochMetric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "EpochMetric") @@ -78,8 +77,8 @@ def compute(self) -> Any: ws = idist.get_world_size() if ws > 1 and not self._is_reduced: # All gather across all processes - _prediction_tensor = evenly_divisible_all_gather(_prediction_tensor) - _target_tensor = evenly_divisible_all_gather(_target_tensor) + _prediction_tensor = evenly_divisible_all_gather(_prediction_tensor, concat=True) + _target_tensor = 
evenly_divisible_all_gather(_target_tensor, concat=True) self._is_reduced = True result: torch.Tensor = torch.zeros(1) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index af35eaa953..3777c2f7d4 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import warnings from collections import OrderedDict from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union @@ -66,6 +67,10 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: The input data on different ranks must have exactly same `dtype`. """ + warnings.warn( + "evenly_divisible_all_gather had been moved to monai.utils module, will deprecate this API in MONAI v0.7.", + DeprecationWarning, + ) if not isinstance(data, torch.Tensor): raise ValueError("input data must be PyTorch Tensor.") @@ -95,6 +100,10 @@ def string_list_all_gather(strings: List[str]) -> List[str]: strings: a list of strings to all gather. """ + warnings.warn( + "string_list_all_gather had been moved to monai.utils module, will deprecate this API in MONAI v0.7.", + DeprecationWarning, + ) world_size = idist.get_world_size() if world_size <= 1: return strings diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index d622ce96ae..324f6aa7d4 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -43,8 +43,10 @@ ensure_tuple, ensure_tuple_rep, ensure_tuple_size, + evenly_divisible_all_gather, fall_back_tuple, first, + get_dist_device, get_seed, is_scalar, is_scalar_tensor, @@ -53,6 +55,7 @@ progress_bar, set_determinism, star_zip_with, + string_list_all_gather, zip_with, ) from .module import ( diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bd8e46d8b5..f99fdc6236 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -17,10 +17,11 @@ import warnings from ast import literal_eval from distutils.util import strtobool -from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast 
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast import numpy as np import torch +import torch.distributed as dist __all__ = [ "zip_with", @@ -41,6 +42,9 @@ "dtype_numpy_to_torch", "MAX_SEED", "copy_to_device", + "get_dist_device", + "evenly_divisible_all_gather", + "string_list_all_gather", "ImageMetaKey", ] @@ -352,6 +356,95 @@ def copy_to_device( return obj +def get_dist_device(): + """ + Get the expected target device in the distributed data parallel. + For NCCL backend, return GPU device of current process. + For GLOO backend, return CPU. + For any other backends, return None as the default, tensor.to(None) will not change the device. + + """ + if dist.is_initialized(): + backend = dist.get_backend() + if backend == "nccl" and torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + elif backend == "gloo": + return torch.device("cpu") + return None + + +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True): + """ + Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. + The input data of every rank should have the same number of dimensions, only the first dim can be different. + + Args: + data: source tensor to pad and execute all_gather in distributed data parallel. + concat: whether to concat the gathered list to be a Tensor, if False, return a list + of Tensors, similar behavior as torch.distributed.all_gather(). default to True. + + Note: + The input data on different ranks must have exactly same `dtype`. 
+ + """ + if not isinstance(data, torch.Tensor): + raise ValueError("input data must be PyTorch Tensor.") + + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if world_size <= 1: + return data + + device = get_dist_device() + orig_device = data.device + data = data.to(device) + # data of all the ranks must have same number of dimensions + ndims = data.ndimension() + if ndims == 0: + # tensor must have batch dimension + data = data.unsqueeze(0) + # make sure the data is evenly-divisible on multi-GPUs + length: int = data.shape[0] + length_tensor = torch.as_tensor([length], device=device) + all_lens = [torch.zeros_like(length_tensor) for _ in range(world_size)] + dist.all_gather(all_lens, length_tensor) + all_lens_: List[int] = [int(i.item()) for i in all_lens] + + max_len: int = max(all_lens_) + if length < max_len: + size = [max_len - length] + list(data.shape[1:]) + data = torch.cat([data, data.new_full(size, 0)], dim=0) + # all gather across all processes + output = [torch.zeros_like(data) for _ in range(world_size)] + dist.all_gather(output, data) + # remove the padding items, if all the input data doesn't have batch dim, squeeze the first dim + output = [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)] + + return torch.cat(output, dim=0) if concat else output + + +def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]: + """ + Utility function for distributed data parallel to all gather a list of strings. + Refer to the idea of ignite `all_gather(string)`: + https://github.com/pytorch/ignite/blob/master/ignite/distributed/utils.py#L346. + + Args: + strings: a list of strings to all gather. + delimiter: use the delimiter to join the string list to be a long string, + then all gather across ranks and split to a list. default to "\t". 
+ + """ + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if world_size <= 1: + return strings + + joined = delimiter.join(strings) + gathered = evenly_divisible_all_gather(torch.tensor(bytearray(joined, "utf-8"), dtype=torch.long), concat=False) + gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered] + + return [i for k in gathered for i in k] + + class ImageMetaKey: """ Common key names in the meta data header of images diff --git a/tests/min_tests.py b/tests/min_tests.py index b474128d3c..782ceeb576 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -110,7 +110,6 @@ def run_testsuit(): "test_randtorchvisiond", "test_handler_metrics_saver", "test_handler_metrics_saver_dist", - "test_evenly_divisible_all_gather_dist", "test_handler_classification_saver_dist", "test_deepgrow_transforms", "test_deepgrow_interaction", @@ -125,6 +124,7 @@ def run_testsuit(): "test_cachedataset_persistent_workers", "test_invertd", "test_handler_post_processing", + "test_write_metrics_reports", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index 70dcd7ca6a..bf3bd1bacc 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -14,7 +14,7 @@ import torch import torch.distributed as dist -from monai.handlers.utils import evenly_divisible_all_gather +from monai.utils import evenly_divisible_all_gather from tests.utils import DistCall, DistTestCase @@ -27,15 +27,21 @@ def _run(self): if dist.get_rank() == 0: data1 = torch.tensor([[1, 2], [3, 4]]) data2 = torch.tensor([[1.0, 2.0]]) + data3 = torch.tensor(7) if dist.get_rank() == 1: data1 = torch.tensor([[5, 6]]) data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) + data3 = torch.tensor(8) - result1 = evenly_divisible_all_gather(data=data1) + result1 = 
evenly_divisible_all_gather(data=data1, concat=True) torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]])) - result2 = evenly_divisible_all_gather(data=data2) - torch.testing.assert_allclose(result2, torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + result2 = evenly_divisible_all_gather(data=data2, concat=False) + for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]): + torch.testing.assert_allclose(r, e) + result3 = evenly_divisible_all_gather(data=data3, concat=False) + for r in result3: + self.assertEqual(r.ndimension(), 0) if __name__ == "__main__": diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 40b2ed87de..359f55f3d8 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -20,10 +20,9 @@ from ignite.engine import Engine from monai.handlers import ClassificationSaver -from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion +from tests.utils import DistCall, DistTestCase -@SkipIfBeforePyTorchVersion((1, 7)) class DistributedHandlerClassificationSaver(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) def test_saved_content(self): diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 2f407d5149..0a36a19c66 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -20,11 +20,10 @@ from ignite.engine import Engine, Events from monai.handlers import MetricsSaver -from monai.handlers.utils import evenly_divisible_all_gather -from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion +from monai.utils import evenly_divisible_all_gather +from tests.utils import DistCall, DistTestCase -@SkipIfBeforePyTorchVersion((1, 7)) class DistributedMetricsSaver(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) def test_content(self): @@ -76,9 +75,9 
@@ def _save_metrics1(engine): @engine.on(Events.EPOCH_COMPLETED) def _all_gather(engine): scores = engine.state.metric_details["metric3"] - engine.state.metric_details["metric3"] = evenly_divisible_all_gather(data=scores) + engine.state.metric_details["metric3"] = evenly_divisible_all_gather(data=scores, concat=True) scores = engine.state.metric_details["metric4"] - engine.state.metric_details["metric4"] = evenly_divisible_all_gather(data=scores) + engine.state.metric_details["metric4"] = evenly_divisible_all_gather(data=scores, concat=True) metrics_saver.attach(engine) engine.run(data, max_epochs=1)