From 2ae5eeed9c7b8a7da588629c2624de3575ed5812 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 18:23:48 +0800 Subject: [PATCH 01/11] [DLMED] update 2 all_gather APIs Signed-off-by: Nic Ma --- monai/handlers/__init__.py | 8 +- monai/handlers/classification_saver.py | 11 +- monai/handlers/iteration_metric.py | 5 +- monai/handlers/metrics_saver.py | 4 +- monai/handlers/roc_auc.py | 7 +- monai/handlers/utils.py | 77 +------------ monai/utils/__init__.py | 3 + monai/utils/misc.py | 103 +++++++++++++++++- tests/min_tests.py | 1 - .../test_evenly_divisible_all_gather_dist.py | 9 +- tests/test_handler_metrics_saver_dist.py | 6 +- 11 files changed, 132 insertions(+), 102 deletions(-) diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 23c0f1bf49..773003805c 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -31,11 +31,5 @@ from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler from .transform_inverter import TransformInverter -from .utils import ( - evenly_divisible_all_gather, - stopping_fn_from_loss, - stopping_fn_from_metric, - string_list_all_gather, - write_metrics_reports, -) +from .utils import stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 1223a10ef6..797b40b566 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -16,9 +16,14 @@ import torch from monai.data import CSVSaver -from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather from monai.utils import ImageMetaKey as Key -from monai.utils import exact_version, issequenceiterable, optional_import +from monai.utils import ( + evenly_divisible_all_gather, + exact_version, + issequenceiterable, + optional_import, + 
string_list_all_gather, +) idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -126,7 +131,7 @@ def _finalize(self, engine: Engine) -> None: outputs = torch.cat(self._outputs, dim=0) filenames = self._filenames if ws > 1: - outputs = evenly_divisible_all_gather(outputs) + outputs = evenly_divisible_all_gather(outputs, concat=True) filenames = string_list_all_gather(filenames) if len(filenames) == 0: diff --git a/monai/handlers/iteration_metric.py b/monai/handlers/iteration_metric.py index 434dd483ed..57653cfc20 100644 --- a/monai/handlers/iteration_metric.py +++ b/monai/handlers/iteration_metric.py @@ -13,9 +13,8 @@ import torch -from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import do_metric_reduction -from monai.utils import MetricReduction, exact_version, optional_import +from monai.utils import MetricReduction, evenly_divisible_all_gather, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric") @@ -104,7 +103,7 @@ def compute(self) -> Any: ws = idist.get_world_size() if ws > 1 and not self._is_reduced: # all gather across all processes - _scores = evenly_divisible_all_gather(data=_scores) + _scores = evenly_divisible_all_gather(data=_scores, concat=True) self._is_reduced = True # save score of every image into engine.state for other components diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index a4a7de584f..1e5314b271 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -11,9 +11,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union -from monai.handlers.utils import string_list_all_gather, write_metrics_reports +from monai.handlers.utils import write_metrics_reports from monai.utils import ImageMetaKey as 
Key -from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import +from monai.utils import ensure_tuple, exact_version, issequenceiterable, string_list_all_gather, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") diff --git a/monai/handlers/roc_auc.py b/monai/handlers/roc_auc.py index 8011dab8db..7a77f4473e 100644 --- a/monai/handlers/roc_auc.py +++ b/monai/handlers/roc_auc.py @@ -13,9 +13,8 @@ import torch -from monai.handlers.utils import evenly_divisible_all_gather from monai.metrics import compute_roc_auc -from monai.utils import Average, exact_version, optional_import +from monai.utils import Average, evenly_divisible_all_gather, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") EpochMetric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "EpochMetric") @@ -78,8 +77,8 @@ def compute(self) -> Any: ws = idist.get_world_size() if ws > 1 and not self._is_reduced: # All gather across all processes - _prediction_tensor = evenly_divisible_all_gather(_prediction_tensor) - _target_tensor = evenly_divisible_all_gather(_target_tensor) + _prediction_tensor = evenly_divisible_all_gather(_prediction_tensor, concat=True) + _target_tensor = evenly_divisible_all_gather(_target_tensor, concat=True) self._is_reduced = True result: torch.Tensor = torch.zeros(1) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index af35eaa953..250ae2b438 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -11,12 +11,12 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union import numpy as np import torch -from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import +from monai.utils 
import ensure_tuple, exact_version, optional_import idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: @@ -24,13 +24,7 @@ else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") -__all__ = [ - "stopping_fn_from_metric", - "stopping_fn_from_loss", - "evenly_divisible_all_gather", - "string_list_all_gather", - "write_metrics_reports", -] +__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "write_metrics_reports"] def stopping_fn_from_metric(metric_name: str): @@ -55,71 +49,6 @@ def stopping_fn(engine: Engine): return stopping_fn -def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: - """ - Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. - - Args: - data: source tensor to pad and execute all_gather in distributed data parallel. - - Note: - The input data on different ranks must have exactly same `dtype`. - - """ - if not isinstance(data, torch.Tensor): - raise ValueError("input data must be PyTorch Tensor.") - - if idist.get_world_size() <= 1: - return data - - # make sure the data is evenly-divisible on multi-GPUs - length = data.shape[0] - all_lens = idist.all_gather(length) - max_len = max(all_lens) - if length < max_len: - size = [max_len - length] + list(data.shape[1:]) - data = torch.cat([data, data.new_full(size, 0)], dim=0) - # all gather across all processes - data = idist.all_gather(data) - # delete the padding NaN items - return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) - - -def string_list_all_gather(strings: List[str]) -> List[str]: - """ - Utility function for distributed data parallel to all gather a list of strings. 
- Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024: - https://github.com/pytorch/ignite/blob/master/ignite/distributed/comp_models/base.py#L92 - - Args: - strings: a list of strings to all gather. - - """ - world_size = idist.get_world_size() - if world_size <= 1: - return strings - - result: List[List[str]] = [[] for _ in range(world_size)] - # get length of strings - length = len(strings) - all_lens = idist.all_gather(length) - max_len = max(all_lens) - # pad the item to make sure the same length - if length < max_len: - strings = strings + ["" for _ in range(max_len - length)] - - if get_torch_version_tuple() > (1, 6, 0): - for s in strings: - gathered = idist.all_gather(s) - for i, g in enumerate(gathered): - if len(g) > 0: - result[i].append(g) - else: - raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") - - return [i for k in result for i in k] - - def write_metrics_reports( save_dir: str, images: Optional[Sequence[str]], diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index d622ce96ae..324f6aa7d4 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -43,8 +43,10 @@ ensure_tuple, ensure_tuple_rep, ensure_tuple_size, + evenly_divisible_all_gather, fall_back_tuple, first, + get_dist_device, get_seed, is_scalar, is_scalar_tensor, @@ -53,6 +55,7 @@ progress_bar, set_determinism, star_zip_with, + string_list_all_gather, zip_with, ) from .module import ( diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bd8e46d8b5..641e76e072 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -17,10 +17,11 @@ import warnings from ast import literal_eval from distutils.util import strtobool -from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast import numpy as np import torch +import torch.distributed as dist __all__ = [ "zip_with", @@ -41,6 +42,9 @@ 
"dtype_numpy_to_torch", "MAX_SEED", "copy_to_device", + "get_dist_device", + "evenly_divisible_all_gather", + "string_list_all_gather", "ImageMetaKey", ] @@ -352,6 +356,103 @@ def copy_to_device( return obj +def get_dist_device(): + """ + Get the expected target device in the distributed data parallel. + For NCCL backend, return GPU device of current process. + For GLOO backend, return CPU. + For any other backends, return None as the default, tensor.to(None) will not change the device. + + """ + if dist.is_initialized(): + backend = dist.get_backend() + if backend == "nccl" and torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + elif backend == "gloo": + return torch.device("cpu") + return None + + +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False) -> torch.Tensor: + """ + Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. + + Args: + data: source tensor to pad and execute all_gather in distributed data parallel. + concat: whether to concat the gathered list to be a Tensor, if False, return a list + of Tensors, same behavior as torch.distributed.all_gather(). default to False. + + Note: + The input data on different ranks must have exactly same `dtype`. 
+ + """ + if not isinstance(data, torch.Tensor): + raise ValueError("input data must be PyTorch Tensor.") + + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if world_size <= 1: + return data + + device = get_dist_device() + orig_device = data.device + data = data.to(device) + # tensor must have batch dimension + if data.ndimension() == 0: + data = data.unsqueeze(0) + # make sure the data is evenly-divisible on multi-GPUs + length = data.shape[0] + length_tensor = torch.as_tensor([length], device=device) + all_lens = [torch.zeros_like(length_tensor) for _ in range(world_size)] + dist.all_gather(all_lens, length_tensor) + + max_len = max(all_lens).item() + if length < max_len: + size = [max_len - length] + list(data.shape[1:]) + data = torch.cat([data, data.new_full(size, 0)], dim=0) + # all gather across all processes + output = [torch.zeros_like(data) for _ in range(world_size)] + dist.all_gather(output, data) + # remove the padding items + output = [o[:l.item(), ...].to(orig_device) for o, l in zip(output, all_lens)] + + return torch.cat(output, dim=0) if concat else output + + +def string_list_all_gather(strings: List[str]) -> List[str]: + """ + Utility function for distributed data parallel to all gather a list of strings. + Refer to the idea of ignite `all_gather(string)`. + + Args: + strings: a list of strings to all gather. 
+ + """ + world_size = dist.get_world_size() if dist.is_initialized() else 1 + if world_size <= 1: + return strings + + result: List[List[str]] = [[] for _ in range(world_size)] + # get length of strings + length = len(strings) + length_tensor = torch.as_tensor([length], device=get_dist_device()) + dist.all_reduce(length_tensor, op=dist.ReduceOp.MAX) + max_len = length_tensor.item() + + # pad the item to make sure the same length + if length < max_len: + strings = strings + ["" for _ in range(max_len - length)] + + for s in strings: + gathered = evenly_divisible_all_gather(data=torch.tensor(bytearray(s, "utf-8"), dtype=torch.long), concat=False) + gathered = [bytearray(g.tolist()).decode("utf-8") for g in gathered] + + for i, g in enumerate(gathered): + if len(g) > 0: + result[i].append(g) + + return [i for k in result for i in k] + + class ImageMetaKey: """ Common key names in the meta data header of images diff --git a/tests/min_tests.py b/tests/min_tests.py index b474128d3c..a0c140d157 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -110,7 +110,6 @@ def run_testsuit(): "test_randtorchvisiond", "test_handler_metrics_saver", "test_handler_metrics_saver_dist", - "test_evenly_divisible_all_gather_dist", "test_handler_classification_saver_dist", "test_deepgrow_transforms", "test_deepgrow_interaction", diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index 70dcd7ca6a..abd8dddd9f 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -14,7 +14,7 @@ import torch import torch.distributed as dist -from monai.handlers.utils import evenly_divisible_all_gather +from monai.utils import evenly_divisible_all_gather from tests.utils import DistCall, DistTestCase @@ -32,10 +32,11 @@ def _run(self): data1 = torch.tensor([[5, 6]]) data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) - result1 = evenly_divisible_all_gather(data=data1) + result1 = 
evenly_divisible_all_gather(data=data1, concat=True) torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]])) - result2 = evenly_divisible_all_gather(data=data2) - torch.testing.assert_allclose(result2, torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + result2 = evenly_divisible_all_gather(data=data2, concat=False) + for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]): + torch.testing.assert_allclose(r, e) if __name__ == "__main__": diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 2f407d5149..0b0082a768 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -20,7 +20,7 @@ from ignite.engine import Engine, Events from monai.handlers import MetricsSaver -from monai.handlers.utils import evenly_divisible_all_gather +from monai.utils import evenly_divisible_all_gather from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion @@ -76,9 +76,9 @@ def _save_metrics1(engine): @engine.on(Events.EPOCH_COMPLETED) def _all_gather(engine): scores = engine.state.metric_details["metric3"] - engine.state.metric_details["metric3"] = evenly_divisible_all_gather(data=scores) + engine.state.metric_details["metric3"] = evenly_divisible_all_gather(data=scores, concat=True) scores = engine.state.metric_details["metric4"] - engine.state.metric_details["metric4"] = evenly_divisible_all_gather(data=scores) + engine.state.metric_details["metric4"] = evenly_divisible_all_gather(data=scores, concat=True) metrics_saver.attach(engine) engine.run(data, max_epochs=1) From a0802ce02a80b6d255934f8dbc4133c2a6790b60 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 18:49:59 +0800 Subject: [PATCH 02/11] [DLMED] remove pytorch version limit Signed-off-by: Nic Ma --- tests/test_handler_classification_saver_dist.py | 3 +-- tests/test_handler_metrics_saver_dist.py | 3 +-- 2 files changed, 2 insertions(+), 4 
deletions(-) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 40b2ed87de..359f55f3d8 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -20,10 +20,9 @@ from ignite.engine import Engine from monai.handlers import ClassificationSaver -from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion +from tests.utils import DistCall, DistTestCase -@SkipIfBeforePyTorchVersion((1, 7)) class DistributedHandlerClassificationSaver(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) def test_saved_content(self): diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 0b0082a768..0a36a19c66 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -21,10 +21,9 @@ from monai.handlers import MetricsSaver from monai.utils import evenly_divisible_all_gather -from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion +from tests.utils import DistCall, DistTestCase -@SkipIfBeforePyTorchVersion((1, 7)) class DistributedMetricsSaver(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) def test_content(self): From e6fef3ed111a1e0f80630b7d80d4812c11d48441 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 18:58:39 +0800 Subject: [PATCH 03/11] [DLMED] optimize perf by joining strings Signed-off-by: Nic Ma --- monai/utils/misc.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 641e76e072..4d1953224f 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -418,40 +418,26 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False) -> tor return torch.cat(output, dim=0) if concat else output -def string_list_all_gather(strings: List[str]) -> List[str]: +def string_list_all_gather(strings: List[str], delimiter: str 
= "\t") -> List[str]: """ Utility function for distributed data parallel to all gather a list of strings. Refer to the idea of ignite `all_gather(string)`. Args: strings: a list of strings to all gather. + delimiter: use the delimiter to join the string list to be a long string, + then all gather across ranks and split to a list. default to "\t". """ world_size = dist.get_world_size() if dist.is_initialized() else 1 if world_size <= 1: return strings - result: List[List[str]] = [[] for _ in range(world_size)] - # get length of strings - length = len(strings) - length_tensor = torch.as_tensor([length], device=get_dist_device()) - dist.all_reduce(length_tensor, op=dist.ReduceOp.MAX) - max_len = length_tensor.item() - - # pad the item to make sure the same length - if length < max_len: - strings = strings + ["" for _ in range(max_len - length)] - - for s in strings: - gathered = evenly_divisible_all_gather(data=torch.tensor(bytearray(s, "utf-8"), dtype=torch.long), concat=False) - gathered = [bytearray(g.tolist()).decode("utf-8") for g in gathered] - - for i, g in enumerate(gathered): - if len(g) > 0: - result[i].append(g) - - return [i for k in result for i in k] + joined = delimiter.join(strings) + gathered = evenly_divisible_all_gather(torch.tensor(bytearray(joined, "utf-8"), dtype=torch.long), concat=False) + gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered] + return [i for k in gathered for i in k] class ImageMetaKey: """ From f6c766a29a53305be5938e12561ac84a9e42a3d8 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 8 Jun 2021 11:02:36 +0000 Subject: [PATCH 04/11] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/metrics_saver.py | 2 +- monai/utils/misc.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 1e5314b271..15814fd3cf 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py 
@@ -13,7 +13,7 @@ from monai.handlers.utils import write_metrics_reports from monai.utils import ImageMetaKey as Key -from monai.utils import ensure_tuple, exact_version, issequenceiterable, string_list_all_gather, optional_import +from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import, string_list_all_gather Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4d1953224f..3482adf74e 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -413,7 +413,7 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False) -> tor output = [torch.zeros_like(data) for _ in range(world_size)] dist.all_gather(output, data) # remove the padding items - output = [o[:l.item(), ...].to(orig_device) for o, l in zip(output, all_lens)] + output = [o[: l.item(), ...].to(orig_device) for o, l in zip(output, all_lens)] return torch.cat(output, dim=0) if concat else output @@ -439,6 +439,7 @@ def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[st return [i for k in gathered for i in k] + class ImageMetaKey: """ Common key names in the meta data header of images From 8220fc7bd67d9783eb52e155f7971f8ab9c55901 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 19:07:25 +0800 Subject: [PATCH 05/11] [DLMED] remove unnecessary import Signed-off-by: Nic Ma --- monai/handlers/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 250ae2b438..bc4dde4abc 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -18,7 +18,6 @@ from monai.utils import ensure_tuple, exact_version, optional_import -idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: From 128bb39c927da8c02776807f43fae8827fcb6be8 Mon Sep 17 
00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 20:25:13 +0800 Subject: [PATCH 06/11] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/utils/misc.py | 10 ++++++---- tests/min_tests.py | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 3482adf74e..20696c0911 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -373,7 +373,8 @@ def get_dist_device(): return None -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False) -> torch.Tensor: +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): + """ Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. @@ -400,12 +401,13 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False) -> tor if data.ndimension() == 0: data = data.unsqueeze(0) # make sure the data is evenly-divisible on multi-GPUs - length = data.shape[0] + length: int = data.shape[0] length_tensor = torch.as_tensor([length], device=device) all_lens = [torch.zeros_like(length_tensor) for _ in range(world_size)] dist.all_gather(all_lens, length_tensor) + all_lens_: List[int] = [int(i.item()) for i in all_lens] - max_len = max(all_lens).item() + max_len: int = max(all_lens_) if length < max_len: size = [max_len - length] + list(data.shape[1:]) data = torch.cat([data, data.new_full(size, 0)], dim=0) @@ -413,7 +415,7 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False) -> tor output = [torch.zeros_like(data) for _ in range(world_size)] dist.all_gather(output, data) # remove the padding items - output = [o[: l.item(), ...].to(orig_device) for o, l in zip(output, all_lens)] + output = [o[: l, ...].to(orig_device) for o, l in zip(output, all_lens_)] return torch.cat(output, dim=0) if concat else output diff --git a/tests/min_tests.py b/tests/min_tests.py index a0c140d157..782ceeb576 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -124,6 
+124,7 @@ def run_testsuit(): "test_cachedataset_persistent_workers", "test_invertd", "test_handler_post_processing", + "test_write_metrics_reports", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 22edaac14a2ae9c2e3e39dcf281df22922635e58 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 8 Jun 2021 12:29:20 +0000 Subject: [PATCH 07/11] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/utils/misc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 20696c0911..23e72a94f6 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -374,7 +374,6 @@ def get_dist_device(): def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): - """ Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. @@ -415,7 +414,7 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): output = [torch.zeros_like(data) for _ in range(world_size)] dist.all_gather(output, data) # remove the padding items - output = [o[: l, ...].to(orig_device) for o, l in zip(output, all_lens_)] + output = [o[:l, ...].to(orig_device) for o, l in zip(output, all_lens_)] return torch.cat(output, dim=0) if concat else output From 95f8ddfe3197b829efb24b1e713db47b61721f5c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 22:35:30 +0800 Subject: [PATCH 08/11] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/utils/misc.py | 14 +++++++++----- tests/test_evenly_divisible_all_gather_dist.py | 5 +++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 23e72a94f6..4e66f8265b 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -376,6 +376,7 @@ def get_dist_device(): def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): """ Utility function for distributed data parallel to 
pad at first dim to make it evenly divisible and all_gather. + The input data of every rank should have the same number of dimensions, only the first dim can be different. Args: data: source tensor to pad and execute all_gather in distributed data parallel. @@ -396,8 +397,10 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): device = get_dist_device() orig_device = data.device data = data.to(device) - # tensor must have batch dimension - if data.ndimension() == 0: + # data of all the ranks must have same number of dimensions + ndims = data.ndimension() + if ndims == 0: + # tensor must have batch dimension data = data.unsqueeze(0) # make sure the data is evenly-divisible on multi-GPUs length: int = data.shape[0] @@ -413,8 +416,8 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): # all gather across all processes output = [torch.zeros_like(data) for _ in range(world_size)] dist.all_gather(output, data) - # remove the padding items - output = [o[:l, ...].to(orig_device) for o, l in zip(output, all_lens_)] + # remove the padding items, if all the input data doesn't have batch dim, squeeze the first dim + output = [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)] return torch.cat(output, dim=0) if concat else output @@ -422,7 +425,8 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]: """ Utility function for distributed data parallel to all gather a list of strings. - Refer to the idea of ignite `all_gather(string)`. + Refer to the idea of ignite `all_gather(string)`: + https://github.com/pytorch/ignite/blob/master/ignite/distributed/utils.py#L346. Args: strings: a list of strings to all gather. 
diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py index abd8dddd9f..bf3bd1bacc 100644 --- a/tests/test_evenly_divisible_all_gather_dist.py +++ b/tests/test_evenly_divisible_all_gather_dist.py @@ -27,16 +27,21 @@ def _run(self): if dist.get_rank() == 0: data1 = torch.tensor([[1, 2], [3, 4]]) data2 = torch.tensor([[1.0, 2.0]]) + data3 = torch.tensor(7) if dist.get_rank() == 1: data1 = torch.tensor([[5, 6]]) data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) + data3 = torch.tensor(8) result1 = evenly_divisible_all_gather(data=data1, concat=True) torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]])) result2 = evenly_divisible_all_gather(data=data2, concat=False) for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]): torch.testing.assert_allclose(r, e) + result3 = evenly_divisible_all_gather(data=data3, concat=False) + for r in result3: + self.assertEqual(r.ndimension(), 0) if __name__ == "__main__": From cbebdffc19ad8808bdc2684911898284b876870d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 23:30:43 +0800 Subject: [PATCH 09/11] [DLMED] add deprecated warning Signed-off-by: Nic Ma --- monai/handlers/__init__.py | 8 +++- monai/handlers/utils.py | 87 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 773003805c..23c0f1bf49 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -31,5 +31,11 @@ from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler from .transform_inverter import TransformInverter -from .utils import stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports +from .utils import ( + evenly_divisible_all_gather, + stopping_fn_from_loss, + stopping_fn_from_metric, + string_list_all_gather, + 
write_metrics_reports, +) from .validation_handler import ValidationHandler diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index bc4dde4abc..1804f88366 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -11,19 +11,27 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union +import warnings import numpy as np import torch -from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import +idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed") if TYPE_CHECKING: from ignite.engine import Engine else: Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") -__all__ = ["stopping_fn_from_metric", "stopping_fn_from_loss", "write_metrics_reports"] +__all__ = [ + "stopping_fn_from_metric", + "stopping_fn_from_loss", + "evenly_divisible_all_gather", + "string_list_all_gather", + "write_metrics_reports", +] def stopping_fn_from_metric(metric_name: str): @@ -48,6 +56,79 @@ def stopping_fn(engine: Engine): return stopping_fn +def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: + """ + Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. + + Args: + data: source tensor to pad and execute all_gather in distributed data parallel. + + Note: + The input data on different ranks must have exactly same `dtype`. 
+ + """ + warnings.warn( + "evenly_divisible_all_gather had been moved to monai.utils module, will deprecate this API in MONAI v0.7.", + DeprecationWarning, + ) + if not isinstance(data, torch.Tensor): + raise ValueError("input data must be PyTorch Tensor.") + + if idist.get_world_size() <= 1: + return data + + # make sure the data is evenly-divisible on multi-GPUs + length = data.shape[0] + all_lens = idist.all_gather(length) + max_len = max(all_lens) + if length < max_len: + size = [max_len - length] + list(data.shape[1:]) + data = torch.cat([data, data.new_full(size, 0)], dim=0) + # all gather across all processes + data = idist.all_gather(data) + # delete the padding NaN items + return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) + + +def string_list_all_gather(strings: List[str]) -> List[str]: + """ + Utility function for distributed data parallel to all gather a list of strings. + Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024: + https://github.com/pytorch/ignite/blob/master/ignite/distributed/comp_models/base.py#L92 + + Args: + strings: a list of strings to all gather. 
+ + """ + warnings.warn( + "string_list_all_gather had been moved to monai.utils module, will deprecate this API in MONAI v0.7.", + DeprecationWarning, + ) + world_size = idist.get_world_size() + if world_size <= 1: + return strings + + result: List[List[str]] = [[] for _ in range(world_size)] + # get length of strings + length = len(strings) + all_lens = idist.all_gather(length) + max_len = max(all_lens) + # pad the item to make sure the same length + if length < max_len: + strings = strings + ["" for _ in range(max_len - length)] + + if get_torch_version_tuple() > (1, 6, 0): + for s in strings: + gathered = idist.all_gather(s) + for i, g in enumerate(gathered): + if len(g) > 0: + result[i].append(g) + else: + raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") + + return [i for k in result for i in k] + + def write_metrics_reports( save_dir: str, images: Optional[Sequence[str]], From c7f7455f47bbefdf8503127f65c7188d03008bff Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 8 Jun 2021 15:42:45 +0000 Subject: [PATCH 10/11] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 1804f88366..3777c2f7d4 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -10,10 +10,10 @@ # limitations under the License. 
import os +import warnings from collections import OrderedDict from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union -import warnings import numpy as np import torch From 7bee08d17dc0b1de7ef71a436e5bdbc398e73475 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 8 Jun 2021 23:49:47 +0800 Subject: [PATCH 11/11] [DLMED] change concat default to True Signed-off-by: Nic Ma --- monai/utils/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 4e66f8265b..f99fdc6236 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -373,7 +373,7 @@ def get_dist_device(): return None -def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): +def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True): """ Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather. The input data of every rank should have the same number of dimensions, only the first dim can be different. @@ -381,7 +381,7 @@ def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = False): Args: data: source tensor to pad and execute all_gather in distributed data parallel. concat: whether to concat the gathered list to be a Tensor, if False, return a list - of Tensors, same behavior as torch.distributed.all_gather(). default to False. + of Tensors, similar behavior as torch.distributed.all_gather(). default to True. Note: The input data on different ranks must have exactly same `dtype`.