72 changes: 72 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1,5 +1,6 @@
import copy

from itertools import chain
import unittest
from unittest.mock import patch
import math
@@ -337,6 +338,77 @@ def test_mark_sharding_partial_unordered(self):
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_tupled_partition_spec(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16).to(xm.xla_device())
xs.mark_sharding(t, mesh, ((0, 1),))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required for tupled partition spec")
def test_named_partial_tupled_partition_spec(self):
mesh = xs.Mesh(
range(self.n_devices), (1, 2, self.n_devices // 2), ('r', 'b', 'm'))
# Shard the first dimension on `r` and `b`, replicate the second dimension
t = torch.randn(16, 16).to(xm.xla_device())
xs.mark_sharding(t, mesh, (('r', 'b'), None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t),
"{devices=[2,1,%d]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))

# Replicate the first dimension, shard the second on `b` and `m`
u = torch.randn(16, 16).to(xm.xla_device())
xs.mark_sharding(u, mesh, (None, ('b', 'm')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))

# Replicate the first dimension, shard the second on `r` and `m`
v = torch.randn(16, 16).to(xm.xla_device())
xs.mark_sharding(v, mesh, (None, ('r', 'm')))
device_order = chain(
range(0, self.n_devices, 2), range(1, self.n_devices, 2))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v),
"{devices=[1,%d,2]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in device_order)))

# Replicate the first dimension, shard the second on `m` and `b`
Collaborator: So ordering of the tuple actually matters? Maybe we should comment it?

Collaborator (Author): I'm not sure about the approach taken by JAX here... I'll add a comment for now, but we should try to match JAX's behavior on this.

Collaborator (Author): Confirmed that the order matters in JAX as well:

>>> z = jax.device_put(x, NamedSharding(mesh, P(('a', 'b'))))
>>> jax.debug.visualize_array_sharding(z)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
└───────────────────────┘
>>> z = jax.device_put(x, NamedSharding(mesh, P(('b', 'a'))))
>>> jax.debug.visualize_array_sharding(z)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 3         │
└───────────────────────┘
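
For reference, a hedged reconstruction of the REPL setup implied by the snippet above; the 2x2 mesh shape, the axis order, and the array `x` are assumptions inferred from the output, not taken from the original comment:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> # Assumed: 4 TPU devices arranged as a 2x2 logical mesh with axes 'a' and 'b'.
>>> devices = np.array(jax.devices()[:4]).reshape(2, 2)
>>> mesh = Mesh(devices, ('a', 'b'))
>>> x = jnp.zeros((8, 8))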

Collaborator: Awesome!

Collaborator: We should have this visualize_array_sharding, haha, it looks so nice.

Contributor: Lol, very good.

v = torch.randn(16, 16).to(xm.xla_device())
xs.mark_sharding(v, mesh, (None, ('m', 'b')))
device_order = chain(
range(0, self.n_devices, 2), range(1, self.n_devices, 2))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in device_order)))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'Multiple devices required for tupled partition spec')
def test_multiple_tuples_in_spec(self):
mesh = xs.Mesh(
range(self.n_devices), (1, 2, self.n_devices // 2, 1),
('a', 'b', 'c', 'd'))
t = torch.randn(2, 2).to(xm.xla_device())
xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'At least 2 devices needed for 2D mesh')
def test_3d_tensor_2d_mesh(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16, 16, 16).to(xm.xla_device())
xs.mark_sharding(t, mesh, (None, 0, 1))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))

def test_partial_replication_addmm(self):
device = xm.xla_device()
z_dim = 2 if self.n_devices >= 4 else 1
108 changes: 76 additions & 32 deletions torch_xla/experimental/xla_sharding.py
@@ -9,7 +9,7 @@

import numpy as np
import itertools
from typing import Tuple, Union, List, Sequence, Any, Optional
from typing import Tuple, Union, List, Sequence, Any, Optional, Set
from enum import IntEnum


@@ -331,40 +331,68 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
return sharding_type


def _get_tile_assignment(mesh: Mesh,
partition_spec: Tuple[Union[int, None]]) -> List[int]:
if (None not in partition_spec) and (len(mesh.mesh_shape)
== len(partition_spec)):
return mesh.get_logical_mesh().transpose(partition_spec).tolist()
# Tile permutation is not necessary for partial replication.
return mesh.get_logical_mesh().tolist()
def _get_tile_assignment(
mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int,
None]]) -> np.ndarray:
"""
Permute the given mesh to create the tile assignment based on the partition
spec. Returns the tiling assignment as a numpy ndarray.

If the input partition_spec combines multiple logical mesh axes over a single
tensor axis, the resulting tiling assignment will combine the specified axes
into a single axis.
"""
# Flatten the partition spec and ensure that it is fully specified over the
# mesh for permutation.
tiled_dims = [x for x in partition_spec if x is not None]
permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
tile_assignment = mesh.get_logical_mesh().transpose(permutation +
missing_axes)

Collaborator: I was shocked by the transpose grammar and then realized it's an np.array. I'm curious what the advantage is of using np.array over a CPU torch.tensor. This is just for discussion; there's no need to make any changes.

# For any tuples in the partition_spec, the grouped axes will be adjacent
# after the permutation. Combine these dimensions into a single axis.
for i, spec in enumerate(tiled_dims):
if isinstance(spec, tuple):
shape = tile_assignment.shape
tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) +
shape[i + len(spec):])

Collaborator: This -1 is the magic to infer the product of the tuple, lol.

Contributor: +1

return tile_assignment
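
A minimal sketch of the permute-and-reshape described in the docstring above, assuming a 2x2 logical mesh over devices 0-3 (values chosen purely for illustration); note how the order of the mesh axes inside the tuple changes the device order of the combined dimension:

>>> import numpy as np
>>> logical_mesh = np.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
>>> # partition_spec ((0, 1),): mesh axis 0 is the major axis of the combined dimension.
>>> logical_mesh.transpose((0, 1)).reshape(-1).tolist()
[0, 1, 2, 3]
>>> # partition_spec ((1, 0),): mesh axis 1 becomes the major axis, so the device order changes.
>>> logical_mesh.transpose((1, 0)).reshape(-1).tolist()
[0, 2, 1, 3]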


# Produce the group assignment for partial replication. Partial replication tiles
# the devices into groups (a.k.a. sub-groups), and the shards are fully replicated
# within each sub-group. `replication_groups` is a list of groups as lists, where
# each group contains the participating device IDs. `group_assignment` describes the
# group placement and the overall mesh, where each element is the group ID.
def _get_group_assignment(
sharding_type: ShardingType, mesh: Mesh,
partition_spec: Tuple[Union[int, None]]) -> Tuple[List, List]:
# The tile_assignment should be the result of `_get_tile_assignment` so that all
# tiled dimensions are in the first axes and replicated dimensions are in the
# remaining axes.
def _get_group_assignment(sharding_type: ShardingType,
tile_assignment: np.ndarray, tensor_rank: int,
replicate_dims: Set[int]) -> Tuple[List, List]:
group_assignment = list()
replication_groups = list()
if sharding_type is ShardingType.PARTIAL:
# Shard across groups and replicate within subgroups; replicated dims
# will be used to group replication devices.
tile_dims = [d for d in partition_spec if d is not None]

group_list = [np.array(mesh.get_logical_mesh().tolist())]
tile_shape = tile_assignment.shape
# When creating the tile assignment, the mesh is permuted so that the first
# few axes are used for tiling.
tile_dims = range(tensor_rank - len(replicate_dims))
group_list = [tile_assignment]
for d in tile_dims:
_group_list = list()
for group_members in group_list:
_group_list += np.split(group_members, mesh.mesh_shape[d], d)
_group_list += np.split(group_members, tile_shape[d], d)
group_list = _group_list
replication_groups = [group.flatten().tolist() for group in group_list]

mesh_axis = itertools.count()
group_tile_shape = [
mesh.mesh_shape[d] if d is not None else 1 for d in partition_spec
1 if d in replicate_dims else tile_shape[next(mesh_axis)]
for d in range(tensor_rank)
]
group_assignment = np.arange(len(replication_groups)).reshape(
tuple(group_tile_shape)).tolist()
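
A hedged sketch of the grouping above for the simplest partial-replication case, assuming 4 devices on a 2x2 mesh and a partition_spec of (0, None), i.e. dim 0 tiled over mesh axis 0 and dim 1 replicated; values are illustrative only:

>>> import numpy as np
>>> tile_assignment = np.arange(4).reshape(2, 2)  # [[0, 1], [2, 3]]
>>> # Split along the tiled axis; each block is one fully replicated sub-group.
>>> [g.flatten().tolist() for g in np.split(tile_assignment, 2, axis=0)]
[[0, 1], [2, 3]]
>>> # Group IDs tile the tensor: 2 groups along dim 0, 1 along the replicated dim 1.
>>> np.arange(2).reshape(2, 1).tolist()
[[0], [1]]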
@@ -374,7 +402,11 @@ def _get_group_assignment(
def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
_partition_spec = list()
for p in partition_spec:
if (p is None) or (type(p) is int):
if type(p) is tuple:
assert not any(type(x) is tuple
for x in p), 'Partition spec cannot contain nested tuples'
_partition_spec.append(_translate_named_partition_spec(mesh, p))
elif (p is None) or (type(p) is int):
_partition_spec.append(p)
elif type(p) is str:
idx = mesh.get_axis_name_idx(p)
@@ -384,13 +416,13 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
else:
raise ValueError(
f"Spec type {type(p)} is not supported in partition spec")
return _partition_spec
return tuple(_partition_spec)


@xr.requires_pjrt
def mark_sharding(
t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
partition_spec: Tuple[Union[int, str, None]]) -> XLAShardedTensor:
partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
"""
Annotates the tensor provided with XLA partition spec. Internally,
it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
@@ -399,8 +431,12 @@

mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.

partition_spec (Tuple[int, str, None]): A tuple of device_mesh dimension index or `None`. Each index is an int or str if the mesh axis is named.
This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension indices or
`None`. Each index is an int, a str if the mesh axis is named, or a tuple of ints or strs.
This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
When a tuple is specified, the corresponding input tensor axis will be sharded along all
logical axes in the tuple. Note that the order in which the mesh axes are specified in the
tuple affects the resulting sharding.
For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
>> input = torch.randn(8, 10)
>> mesh_shape = (4, 2)
@@ -426,34 +462,38 @@ def mark_sharding(
assert mesh.size() == num_devices, \
f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
partition_spec = _translate_named_partition_spec(mesh, partition_spec)
assert all((d >= 0 and d < len(mesh.mesh_shape)) for d in partition_spec if d), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
# We only allow fully specified `partition_spec` to be applicable, as opposed
# to filling in the unspecified replicated dims. Fully specified `partition_spec`
# should be of the same rank as `t`. This is to support partial replication
# where the group assignment may vary with different input ranks.
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
specs = [d for d in partition_spec if d]
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tile_assignment = _get_tile_assignment(mesh, partition_spec)
if len(mesh.mesh_shape) > len(partition_spec):
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
sharding_type = ShardingType.PARTIAL
else:
sharding_type = _get_sharding_type(partition_spec, num_devices)
replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
group_assignment, replication_groups = _get_group_assignment(
sharding_type, mesh, partition_spec)
sharding_type, tile_assignment, len(partition_spec), replicate_dims)

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding(t.global_tensor, tile_assignment,
torch_xla._XLAC._xla_mark_sharding(t.global_tensor,
tile_assignment.tolist(),
group_assignment, replication_groups,
int(sharding_type))
return t
torch_xla._XLAC._xla_mark_sharding(t, tile_assignment, group_assignment,
replication_groups, int(sharding_type))
torch_xla._XLAC._xla_mark_sharding(t, tile_assignment.tolist(),
group_assignment, replication_groups,
int(sharding_type))
return XLAShardedTensor(t)
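
A minimal usage sketch of a tupled partition spec, mirroring test_named_partial_tupled_partition_spec above and assuming a runtime with 4 XLA devices:

>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.experimental.xla_sharding as xs
>>> # 1x2x2 logical mesh with named axes.
>>> mesh = xs.Mesh(range(4), (1, 2, 2), ('r', 'b', 'm'))
>>> t = torch.randn(16, 16).to(xm.xla_device())
>>> # Shard dim 0 jointly over the 'r' and 'b' mesh axes; replicate dim 1 over 'm'.
>>> xt = xs.mark_sharding(t, mesh, (('r', 'b'), None))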


Expand Down Expand Up @@ -491,12 +531,16 @@ class ShardingSpec:

@xr.requires_pjrt
def __post_init__(self):
partition_spec, mesh = self.partition_spec, self.mesh
self._tile_assignment = _get_tile_assignment(mesh, partition_spec)
mesh = self.mesh
partition_spec = _translate_named_partition_spec(mesh, self.partition_spec)
tile_assignment = _get_tile_assignment(mesh, partition_spec)
self._tile_assignment = tile_assignment.tolist()
self._sharding_type = _get_sharding_type(partition_spec,
xr.global_runtime_device_count())
replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
self._group_assignment, self._replication_groups = _get_group_assignment(
self._sharding_type, mesh, partition_spec)
self._sharding_type, tile_assignment, len(partition_spec),
replicate_dims)

def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
"""