Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions monai/handlers/classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
import torch

from monai.data import CSVSaver
from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
from monai.utils import ImageMetaKey as Key
from monai.utils import exact_version, issequenceiterable, optional_import
from monai.utils import (
evenly_divisible_all_gather,
exact_version,
issequenceiterable,
optional_import,
string_list_all_gather,
)

idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
Expand Down Expand Up @@ -126,7 +131,7 @@ def _finalize(self, engine: Engine) -> None:
outputs = torch.cat(self._outputs, dim=0)
filenames = self._filenames
if ws > 1:
outputs = evenly_divisible_all_gather(outputs)
outputs = evenly_divisible_all_gather(outputs, concat=True)
filenames = string_list_all_gather(filenames)

if len(filenames) == 0:
Expand Down
5 changes: 2 additions & 3 deletions monai/handlers/iteration_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@

import torch

from monai.handlers.utils import evenly_divisible_all_gather
from monai.metrics import do_metric_reduction
from monai.utils import MetricReduction, exact_version, optional_import
from monai.utils import MetricReduction, evenly_divisible_all_gather, exact_version, optional_import

idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
Metric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "Metric")
Expand Down Expand Up @@ -104,7 +103,7 @@ def compute(self) -> Any:
ws = idist.get_world_size()
if ws > 1 and not self._is_reduced:
# all gather across all processes
_scores = evenly_divisible_all_gather(data=_scores)
_scores = evenly_divisible_all_gather(data=_scores, concat=True)
self._is_reduced = True

# save score of every image into engine.state for other components
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union

from monai.handlers.utils import string_list_all_gather, write_metrics_reports
from monai.handlers.utils import write_metrics_reports
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import
from monai.utils import ensure_tuple, exact_version, issequenceiterable, optional_import, string_list_all_gather

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
Expand Down
7 changes: 3 additions & 4 deletions monai/handlers/roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@

import torch

from monai.handlers.utils import evenly_divisible_all_gather
from monai.metrics import compute_roc_auc
from monai.utils import Average, exact_version, optional_import
from monai.utils import Average, evenly_divisible_all_gather, exact_version, optional_import

idist, _ = optional_import("ignite", "0.4.4", exact_version, "distributed")
EpochMetric, _ = optional_import("ignite.metrics", "0.4.4", exact_version, "EpochMetric")
Expand Down Expand Up @@ -78,8 +77,8 @@ def compute(self) -> Any:
ws = idist.get_world_size()
if ws > 1 and not self._is_reduced:
# All gather across all processes
_prediction_tensor = evenly_divisible_all_gather(_prediction_tensor)
_target_tensor = evenly_divisible_all_gather(_target_tensor)
_prediction_tensor = evenly_divisible_all_gather(_prediction_tensor, concat=True)
_target_tensor = evenly_divisible_all_gather(_target_tensor, concat=True)
self._is_reduced = True

result: torch.Tensor = torch.zeros(1)
Expand Down
9 changes: 9 additions & 0 deletions monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import os
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union

Expand Down Expand Up @@ -66,6 +67,10 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
The input data on different ranks must have exactly same `dtype`.

"""
warnings.warn(
"evenly_divisible_all_gather had been moved to monai.utils module, will deprecate this API in MONAI v0.7.",
DeprecationWarning,
)
if not isinstance(data, torch.Tensor):
raise ValueError("input data must be PyTorch Tensor.")

Expand Down Expand Up @@ -95,6 +100,10 @@ def string_list_all_gather(strings: List[str]) -> List[str]:
strings: a list of strings to all gather.

"""
warnings.warn(
"string_list_all_gather had been moved to monai.utils module, will deprecate this API in MONAI v0.7.",
DeprecationWarning,
)
world_size = idist.get_world_size()
if world_size <= 1:
return strings
Expand Down
3 changes: 3 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
ensure_tuple,
ensure_tuple_rep,
ensure_tuple_size,
evenly_divisible_all_gather,
fall_back_tuple,
first,
get_dist_device,
get_seed,
is_scalar,
is_scalar_tensor,
Expand All @@ -53,6 +55,7 @@
progress_bar,
set_determinism,
star_zip_with,
string_list_all_gather,
zip_with,
)
from .module import (
Expand Down
95 changes: 94 additions & 1 deletion monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import warnings
from ast import literal_eval
from distutils.util import strtobool
from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
import torch.distributed as dist

__all__ = [
"zip_with",
Expand All @@ -41,6 +42,9 @@
"dtype_numpy_to_torch",
"MAX_SEED",
"copy_to_device",
"get_dist_device",
"evenly_divisible_all_gather",
"string_list_all_gather",
"ImageMetaKey",
]

Expand Down Expand Up @@ -352,6 +356,95 @@ def copy_to_device(
return obj


Comment thread
wyli marked this conversation as resolved.
def get_dist_device():
"""
Get the expected target device in the distributed data parallel.
For NCCL backend, return GPU device of current process.
For GLOO backend, return CPU.
For any other backends, return None as the default, tensor.to(None) will not change the device.

"""
if dist.is_initialized():
backend = dist.get_backend()
if backend == "nccl" and torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif backend == "gloo":
return torch.device("cpu")
return None


def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True):
"""
Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather.
The input data of every rank should have the same number of dimensions, only the first dim can be different.

Args:
data: source tensor to pad and execute all_gather in distributed data parallel.
concat: whether to concat the gathered list to be a Tensor, if False, return a list
of Tensors, similar behavior as torch.distributed.all_gather(). default to True.

Note:
The input data on different ranks must have exactly same `dtype`.

"""
if not isinstance(data, torch.Tensor):
raise ValueError("input data must be PyTorch Tensor.")

world_size = dist.get_world_size() if dist.is_initialized() else 1
if world_size <= 1:
return data

device = get_dist_device()
orig_device = data.device
data = data.to(device)
# data of all the ranks must have same number of dimensions
ndims = data.ndimension()
if ndims == 0:
# tensor must have batch dimension
data = data.unsqueeze(0)
Comment thread
Nic-Ma marked this conversation as resolved.
# make sure the data is evenly-divisible on multi-GPUs
length: int = data.shape[0]
length_tensor = torch.as_tensor([length], device=device)
all_lens = [torch.zeros_like(length_tensor) for _ in range(world_size)]
dist.all_gather(all_lens, length_tensor)
all_lens_: List[int] = [int(i.item()) for i in all_lens]

max_len: int = max(all_lens_)
if length < max_len:
size = [max_len - length] + list(data.shape[1:])
data = torch.cat([data, data.new_full(size, 0)], dim=0)
# all gather across all processes
output = [torch.zeros_like(data) for _ in range(world_size)]
dist.all_gather(output, data)
# remove the padding items, if all the input data doesn't have batch dim, suqeeze the first dim
output = [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)]

return torch.cat(output, dim=0) if concat else output


def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
"""
Utility function for distributed data parallel to all gather a list of strings.
Refer to the idea of ignite `all_gather(string)`:
https://github.com/pytorch/ignite/blob/master/ignite/distributed/utils.py#L346.

Args:
strings: a list of strings to all gather.
delimiter: use the delimiter to join the string list to be a long string,
then all gather across ranks and split to a list. default to "\t".

"""
world_size = dist.get_world_size() if dist.is_initialized() else 1
if world_size <= 1:
return strings

joined = delimiter.join(strings)
gathered = evenly_divisible_all_gather(torch.tensor(bytearray(joined, "utf-8"), dtype=torch.long), concat=False)
gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered]

return [i for k in gathered for i in k]


class ImageMetaKey:
"""
Common key names in the meta data header of images
Expand Down
2 changes: 1 addition & 1 deletion tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def run_testsuit():
"test_randtorchvisiond",
"test_handler_metrics_saver",
"test_handler_metrics_saver_dist",
"test_evenly_divisible_all_gather_dist",
"test_handler_classification_saver_dist",
"test_deepgrow_transforms",
"test_deepgrow_interaction",
Expand All @@ -125,6 +124,7 @@ def run_testsuit():
"test_cachedataset_persistent_workers",
"test_invertd",
"test_handler_post_processing",
"test_write_metrics_reports",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
14 changes: 10 additions & 4 deletions tests/test_evenly_divisible_all_gather_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
import torch.distributed as dist

from monai.handlers.utils import evenly_divisible_all_gather
from monai.utils import evenly_divisible_all_gather
from tests.utils import DistCall, DistTestCase


Expand All @@ -27,15 +27,21 @@ def _run(self):
if dist.get_rank() == 0:
data1 = torch.tensor([[1, 2], [3, 4]])
data2 = torch.tensor([[1.0, 2.0]])
data3 = torch.tensor(7)

if dist.get_rank() == 1:
data1 = torch.tensor([[5, 6]])
data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]])
data3 = torch.tensor(8)

result1 = evenly_divisible_all_gather(data=data1)
result1 = evenly_divisible_all_gather(data=data1, concat=True)
torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))
result2 = evenly_divisible_all_gather(data=data2)
torch.testing.assert_allclose(result2, torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
result2 = evenly_divisible_all_gather(data=data2, concat=False)
for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]):
torch.testing.assert_allclose(r, e)
result3 = evenly_divisible_all_gather(data=data3, concat=False)
for r in result3:
self.assertEqual(r.ndimension(), 0)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions tests/test_handler_classification_saver_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
from ignite.engine import Engine

from monai.handlers import ClassificationSaver
from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion
from tests.utils import DistCall, DistTestCase


@SkipIfBeforePyTorchVersion((1, 7))
class DistributedHandlerClassificationSaver(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2)
def test_saved_content(self):
Expand Down
9 changes: 4 additions & 5 deletions tests/test_handler_metrics_saver_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
from ignite.engine import Engine, Events

from monai.handlers import MetricsSaver
from monai.handlers.utils import evenly_divisible_all_gather
from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion
from monai.utils import evenly_divisible_all_gather
from tests.utils import DistCall, DistTestCase


@SkipIfBeforePyTorchVersion((1, 7))
class DistributedMetricsSaver(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2)
def test_content(self):
Expand Down Expand Up @@ -76,9 +75,9 @@ def _save_metrics1(engine):
@engine.on(Events.EPOCH_COMPLETED)
def _all_gather(engine):
scores = engine.state.metric_details["metric3"]
engine.state.metric_details["metric3"] = evenly_divisible_all_gather(data=scores)
engine.state.metric_details["metric3"] = evenly_divisible_all_gather(data=scores, concat=True)
scores = engine.state.metric_details["metric4"]
engine.state.metric_details["metric4"] = evenly_divisible_all_gather(data=scores)
engine.state.metric_details["metric4"] = evenly_divisible_all_gather(data=scores, concat=True)

metrics_saver.attach(engine)
engine.run(data, max_epochs=1)
Expand Down