Merged
16 changes: 16 additions & 0 deletions test/test_mp_all_gather.py
@@ -6,6 +6,10 @@
import torch_xla.distributed.xla_multiprocessing as xmp


def all_gather(tensor, dim):
return xm.all_gather(tensor, dim=dim)


def _mp_fn(index):
device = xm.xla_device()
world_size = xm.xrt_world_size()
@@ -14,6 +18,18 @@ def _mp_fn(index):
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)

cpu_result = result.cpu()
expected = torch.arange(0, world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
print('xm.all_gather() produced wrong reductions', file=sys.stderr)
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

compiled_all_gather = torch.compile(
all_gather, backend='torchxla_trace_once', fullgraph=True)
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = compiled_all_gather(ordinal_tensor, dim=0)

cpu_result = result.cpu()
expected = torch.arange(0, world_size, dtype=torch.float)
if not cpu_result.allclose(expected):
33 changes: 31 additions & 2 deletions torch_xla/core/xla_model.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
import torch_xla
from torch_xla.experimental import pjrt
from torch_xla.experimental import tpu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.debug.metrics_saver as ms
import torch_xla.utils.utils as xu
@@ -26,6 +27,27 @@
_DEVICE_CONTEXTS = dict()
_DEVICE_CONTEXTS_LOCK = threading.Lock()

# Note [Dynamo WORLD_SIZE and ORDINAL]
# Below is a workaround to cache the ordinal and world_size so that
# Dynamo won't insert graph breaks when xm.xrt_world_size() and xm.get_ordinal() are called.
_WORLD_SIZE = None
_ORDINAL = None


def _init_world_size_ordinal():
global _WORLD_SIZE, _ORDINAL

if not pjrt.using_pjrt():
return

# We don't support V3-8. See Note [V3-8 Threading]
if pjrt.device_type() == 'TPU' and tpu.version() < 4:
return

if _WORLD_SIZE is None:
_WORLD_SIZE = xrt_world_size()
_ORDINAL = get_ordinal()
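
A minimal usage sketch (not part of this diff) of the pattern the caching above enables, assuming an XLA device is already configured for the process; the function and tensor names are illustrative:

```python
import torch
import torch_xla.core.xla_model as xm

def gather_and_scale(t):
    # With _WORLD_SIZE cached, this call returns a plain Python int
    # instead of forcing Dynamo to break the graph.
    world_size = xm.xrt_world_size()
    return xm.all_gather(t, dim=0) / world_size

compiled = torch.compile(
    gather_and_scale, backend='torchxla_trace_once', fullgraph=True)
```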


class DeviceContext(object):

@@ -90,6 +112,10 @@ def xrt_world_size(defval=1):
Returns:
The number of devices which is taking part of the replication.
"""
global _WORLD_SIZE
if _WORLD_SIZE is not None:
return _WORLD_SIZE

if pjrt.using_pjrt():
return pjrt.world_size()

@@ -109,6 +135,10 @@ def get_ordinal(defval=0):
Returns:
The replication ordinal of the current thread.
"""
global _ORDINAL
if _ORDINAL is not None:
return _ORDINAL

if pjrt.using_pjrt():
return pjrt.global_ordinal()

@@ -533,8 +563,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
A tensor which has, in the ``dim`` dimension, all the values from the
participating replicas.
"""
if pin_layout and xla_device_hw(
value.device) in ('TPU', 'GPU', 'XPU') and output == None:
Collaborator: I think we had it because CPU was not supported at some point. Do you need to remove it because it will break dynamo?

Author: Yea.

if pin_layout and output == None:
# There is not an easy way to pin the all_gather layout on TPU and GPU, so use
# the all_reduce-based all_gather for this purpose.
return _all_gather_using_all_reduce(
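For reference, a minimal sketch (not part of the diff) of what the change above means for callers: with pin_layout=True and no output tensor, all_gather now takes the all_reduce-based path on any device, not only TPU/GPU/XPU. Variable names are illustrative:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.tensor([xm.get_ordinal()], dtype=torch.float, device=device)

# pin_layout=True with output=None: after this change, the all_reduce-based
# all_gather path is used regardless of the underlying hardware.
gathered = xm.all_gather(t, dim=0, pin_layout=True)
```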
1 change: 1 addition & 0 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -18,6 +18,7 @@
namespace torch_xla {
namespace {

// Note [V3-8 Threading]
// For V3-8 + PJRT, we have 4 processes and each process has 2 threads to manage
// the 8 cores. Therefore, we need different tokens for different threads.
std::unordered_map<int64_t, std::shared_ptr<torch::lazy::Value>>
3 changes: 3 additions & 0 deletions torch_xla/experimental/pjrt.py
@@ -235,6 +235,9 @@ def _run_thread_per_device(
def _thread_fn(device: torch.device):
torch_xla._XLAC._xla_set_default_device(device)

# See Note [Dynamo WORLD_SIZE and ORDINAL].
xm._init_world_size_ordinal()

return fn()

with concurrent.futures.ThreadPoolExecutor(