Support tuples in partition spec #5488
Changes from all commits: 4bb16b9, 9aa32d6, 17416da, 6e19717, 8493547, 2ac4dd4, b7f5a61, 54258af, 1006a89
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import copy | ||
|
|
||
| from itertools import chain | ||
| import unittest | ||
| from unittest.mock import patch | ||
| import math | ||
|
|
@@ -337,6 +338,77 @@ def test_mark_sharding_partial_unordered(self): | |
| actual = (xt1 + t2).cpu() | ||
| self.assertTrue(torch.allclose(expected, actual)) | ||
|
|
||
| @unittest.skipUnless(xr.global_runtime_device_count() > 1, | ||
| "Multiple devices required for tupled partition spec") | ||
| def test_tupled_partition_spec(self): | ||
| mesh = self._get_mesh((2, self.n_devices // 2)) | ||
| t = torch.randn(16).to(xm.xla_device()) | ||
| xs.mark_sharding(t, mesh, ((0, 1),)) | ||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" % | ||
| (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) | ||
|
|
||
| @unittest.skipUnless(xr.global_runtime_device_count() >= 4, | ||
| "Multiple devices required for tupled partition spec") | ||
| def test_named_partial_tupled_partition_spec(self): | ||
| mesh = xs.Mesh( | ||
| range(self.n_devices), (1, 2, self.n_devices // 2), ('r', 'b', 'm')) | ||
| # Shard the first dimension on `r` and `b`, replicate the second dimension | ||
| t = torch.randn(16, 16).to(xm.xla_device()) | ||
| xs.mark_sharding(t, mesh, (('r', 'b'), None)) | ||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(t), | ||
| "{devices=[2,1,%d]%s last_tile_dim_replicate}" % | ||
| (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) | ||
|
|
||
| # Replicate the first dimension, shard the second on `b` and `m` | ||
| u = torch.randn(16, 16).to(xm.xla_device()) | ||
| xs.mark_sharding(u, mesh, (None, ('b', 'm'))) | ||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" % | ||
| (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) | ||
|
|
||
| # Replicate the first dimension, shard the second on `r` and `m` | ||
| v = torch.randn(16, 16).to(xm.xla_device()) | ||
| xs.mark_sharding(v, mesh, (None, ('r', 'm'))) | ||
| device_order = chain( | ||
| range(0, self.n_devices, 2), range(1, self.n_devices, 2)) | ||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(v), | ||
| "{devices=[1,%d,2]%s last_tile_dim_replicate}" % | ||
| (self.n_devices // 2, ','.join(str(x) for x in device_order))) | ||
|
|
||
| # Replicate the first dimension, shard the second on `m` and `b` | ||
|
Collaborator: So ordering of the tuple actually matters? Maybe we should add a comment about it?
Author: I'm not sure about the approach taken by JAX here... I'll add a comment for now, but we should try to match JAX's behavior on this.
Author: Confirmed the order matters in JAX as well:
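The snippet referenced here isn't preserved above; a minimal JAX sketch of the kind of check that would confirm this might look like the following. The 2x4 mesh, the axis names 'x'/'y', and the 8-device assumption are illustrative, not taken from the PR.

```python
# Illustrative only: assumes 8 available devices and made-up axis names.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(devices, ('x', 'y'))
x = jnp.arange(16)

# Combine both mesh axes over the single tensor axis, in the two orders.
a = jax.device_put(x, NamedSharding(mesh, P(('x', 'y'))))
b = jax.device_put(x, NamedSharding(mesh, P(('y', 'x'))))

# The shard-to-device assignment differs, so the order inside the tuple matters.
jax.debug.visualize_array_sharding(a)
jax.debug.visualize_array_sharding(b)
```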
Collaborator: Awesome!
Collaborator: we should have this
Contributor: Lol very good.
||
| v = torch.randn(16, 16).to(xm.xla_device()) | ||
| xs.mark_sharding(v, mesh, (None, ('m', 'b'))) | ||
| device_order = chain( | ||
| range(0, self.n_devices, 2), range(1, self.n_devices, 2)) | ||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" % | ||
| (self.n_devices, ','.join(str(x) for x in device_order))) | ||
|
|
||
| @unittest.skipUnless(xr.global_runtime_device_count() > 1, | ||
| 'Multiple devices required for tupled partition spec') | ||
| def test_multiple_tuples_in_spec(self): | ||
| mesh = xs.Mesh( | ||
| range(self.n_devices), (1, 2, self.n_devices // 2, 1), | ||
| ('a', 'b', 'c', 'd')) | ||
| t = torch.randn(2, 2).to(xm.xla_device()) | ||
| xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd'))) | ||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" % | ||
| (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) | ||
|
|
||
| @unittest.skipUnless(xr.global_runtime_device_count() > 1, | ||
| 'At least 2 devices needed for 2D mesh') | ||
| def test_3d_tensor_2d_mesh(self): | ||
| mesh = self._get_mesh((2, self.n_devices // 2)) | ||
| t = torch.randn(16, 16, 16).to(xm.xla_device()) | ||
| xs.mark_sharding(t, mesh, (None, 0, 1)) | ||
|
||
| self.assertEqual( | ||
| torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' % | ||
| (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) | ||
|
|
||
| def test_partial_replication_addmm(self): | ||
| device = xm.xla_device() | ||
| z_dim = 2 if self.n_devices >= 4 else 1 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |
|
|
||
| import numpy as np | ||
| import itertools | ||
| from typing import Tuple, Union, List, Sequence, Any, Optional | ||
| from typing import Tuple, Union, List, Sequence, Any, Optional, Set | ||
| from enum import IntEnum | ||
|
|
||
|
|
||
|
|
@@ -331,40 +331,68 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]], | |
| return sharding_type | ||
|
|
||
|
|
||
| def _get_tile_assignment(mesh: Mesh, | ||
| partition_spec: Tuple[Union[int, None]]) -> List[int]: | ||
| if (None not in partition_spec) and (len(mesh.mesh_shape) | ||
| == len(partition_spec)): | ||
| return mesh.get_logical_mesh().transpose(partition_spec).tolist() | ||
| # Tile permutation is not necessary for partial replication. | ||
| return mesh.get_logical_mesh().tolist() | ||
| def _get_tile_assignment( | ||
| mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int, | ||
| None]]) -> np.ndarray: | ||
| """ | ||
| Permute the given mesh to create the tile assignment based on the partition | ||
| spec. Returns the tiling assignment as a numpy ndarray. | ||
|
|
||
| If the input partition_spec combines multiple logical mesh axes over a single | ||
| tensor axis, the resulting tiling assignment will combine the specified axes | ||
| into a single axis. | ||
| """ | ||
| # Flatten the partition spec and ensure that it is fully specified over the | ||
| # mesh for permutation. | ||
| tiled_dims = [x for x in partition_spec if x is not None] | ||
| permutation = np.hstack(tiled_dims).tolist() if tiled_dims else [] | ||
| missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation)) | ||
| tile_assignment = mesh.get_logical_mesh().transpose(permutation + | ||
|
Collaborator: I was shocked by the transpose grammar and then realized it's np.array. I'm curious what's the advantage of using np.array over a CPU torch.tensor. This is just for discussion; there is no need to make any changes.
||
| missing_axes) | ||
|
|
||
| # For any tuples in the partition_spec, the grouped axes will be adjacent | ||
| # after the permutation. Combine these dimensions into a single axis. | ||
| for i, spec in enumerate(tiled_dims): | ||
| if isinstance(spec, tuple): | ||
| shape = tile_assignment.shape | ||
| tile_assignment = tile_assignment.reshape(shape[:i] + (-1,) + | ||
|
Collaborator: This
Contributor: +1
||
| shape[i + len(spec):]) | ||
|
|
||
| return tile_assignment | ||
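To make the permute-then-collapse step concrete, here is a small standalone numpy sketch of what the new `_get_tile_assignment` computes for `partition_spec = ((0, 1), None)` on a hypothetical 2x4 mesh; the mesh and spec are made up for illustration, and a plain ndarray stands in for `mesh.get_logical_mesh()`.

```python
import numpy as np

# Hypothetical 2x4 logical mesh of device IDs 0..7, standing in for
# mesh.get_logical_mesh().
logical_mesh = np.arange(8).reshape(2, 4)
partition_spec = ((0, 1), None)

# Flatten the spec and move the referenced mesh axes to the front.
tiled_dims = [x for x in partition_spec if x is not None]
permutation = np.hstack(tiled_dims).tolist()                  # [0, 1]
missing_axes = sorted(set(range(logical_mesh.ndim)) - set(permutation))
tile_assignment = logical_mesh.transpose(permutation + missing_axes)

# The tuple (0, 1) groups two mesh axes over one tensor axis, so the two
# leading dimensions collapse into a single tile dimension of size 8.
tile_assignment = tile_assignment.reshape((-1,) + tile_assignment.shape[2:])
print(tile_assignment.shape)     # (8,)
print(tile_assignment.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
```

This is the same device ordering the tupled-spec tests above assert on.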
|
|
||
|
|
||
| # Produce group assignment for partial replication. Partial replication tiles | ||
| # groups (a.k.a. sub-groups) where the shards are fully replicated within each | ||
| # sub-group. `replication_groups` is a list of groups as lists, where each group | ||
| # contains the participating device IDs. `group_assignment` describes the group | ||
| # placement and the overall mesh, where each element is the group ID. | ||
| def _get_group_assignment( | ||
| sharding_type: ShardingType, mesh: Mesh, | ||
| partition_spec: Tuple[Union[int, None]]) -> Tuple[List, List]: | ||
| # The tile_assignment should be the result of `_get_tile_assignment` so that all | ||
| # tiled dimensions are in the first axes and replicated dimensions are in the | ||
| # remaining axes. | ||
| def _get_group_assignment(sharding_type: ShardingType, | ||
| tile_assignment: np.ndarray, tensor_rank: int, | ||
| replicate_dims: Set[int]) -> Tuple[List, List]: | ||
| group_assignment = list() | ||
| replication_groups = list() | ||
| if sharding_type is ShardingType.PARTIAL: | ||
| # Shard across groups and replicate within subgroups; replicated dims | ||
| # will be used to group replication devices. | ||
| tile_dims = [d for d in partition_spec if d is not None] | ||
|
|
||
| group_list = [np.array(mesh.get_logical_mesh().tolist())] | ||
| tile_shape = tile_assignment.shape | ||
| # When creating the tile assignment, the mesh is permuted so that the first | ||
| # few axes are used for tiling. | ||
| tile_dims = range(tensor_rank - len(replicate_dims)) | ||
| group_list = [tile_assignment] | ||
| for d in tile_dims: | ||
| _group_list = list() | ||
| for group_members in group_list: | ||
| _group_list += np.split(group_members, mesh.mesh_shape[d], d) | ||
| _group_list += np.split(group_members, tile_shape[d], d) | ||
| group_list = _group_list | ||
| replication_groups = [group.flatten().tolist() for group in group_list] | ||
|
|
||
| mesh_axis = itertools.count() | ||
| group_tile_shape = [ | ||
| mesh.mesh_shape[d] if d is not None else 1 for d in partition_spec | ||
| 1 if d in replicate_dims else tile_shape[next(mesh_axis)] | ||
| for d in range(tensor_rank) | ||
| ] | ||
| group_assignment = np.arange(len(replication_groups)).reshape( | ||
| tuple(group_tile_shape)).tolist() | ||
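As a concrete illustration of the grouping logic above, here is a standalone numpy sketch under assumed inputs (a 2x4 mesh with a `(None, 1)` partition spec); it reproduces only the PARTIAL branch of the new `_get_group_assignment`.

```python
import itertools
import numpy as np

# Illustrative only: a 2x4 mesh and a rank-2 tensor whose first axis is
# replicated and whose second axis is sharded on mesh axis 1. The tile
# assignment is then the permuted mesh with the sharded axis leading: (4, 2).
tile_assignment = np.arange(8).reshape(2, 4).transpose([1, 0])
tensor_rank, replicate_dims = 2, {0}

# Repeatedly split along each tiled dimension; each remaining block is one
# replication group.
tile_shape = tile_assignment.shape
group_list = [tile_assignment]
for d in range(tensor_rank - len(replicate_dims)):
    group_list = [piece for group in group_list
                  for piece in np.split(group, tile_shape[d], d)]
replication_groups = [g.flatten().tolist() for g in group_list]
print(replication_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]

# Group IDs laid out over the tensor dims; replicated dims get extent 1.
mesh_axis = itertools.count()
group_tile_shape = [1 if d in replicate_dims else tile_shape[next(mesh_axis)]
                    for d in range(tensor_rank)]
group_assignment = np.arange(len(replication_groups)).reshape(group_tile_shape).tolist()
print(group_assignment)  # [[0, 1, 2, 3]]
```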
|
|
@@ -374,7 +402,11 @@ def _get_group_assignment( | |
| def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): | ||
| _partition_spec = list() | ||
| for p in partition_spec: | ||
| if (p is None) or (type(p) is int): | ||
| if type(p) is tuple: | ||
| assert not any(type(x) is tuple | ||
| for x in p), 'Partition spec cannot contain nested tuples' | ||
| _partition_spec.append(_translate_named_partition_spec(mesh, p)) | ||
| elif (p is None) or (type(p) is int): | ||
| _partition_spec.append(p) | ||
| elif type(p) is str: | ||
| idx = mesh.get_axis_name_idx(p) | ||
|
|
@@ -384,13 +416,13 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): | |
| else: | ||
| raise ValueError( | ||
| f"Spec type {type(p)} is not supported in partition spec") | ||
| return _partition_spec | ||
| return tuple(_partition_spec) | ||
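For reference, the effect of the translation on the named, tupled specs used in the tests above can be sketched with a small standalone helper; this mimics, rather than calls, the library function, and the axis-name-to-index mapping is assumed from a mesh with axes ('r', 'b', 'm').

```python
# Standalone mimic of the name translation; not the library function itself.
axis_name_to_idx = {'r': 0, 'b': 1, 'm': 2}

def translate(spec):
    # Tuples are translated element-wise; ints and None pass through unchanged.
    return tuple(
        translate(p) if isinstance(p, tuple)
        else (axis_name_to_idx[p] if isinstance(p, str) else p)
        for p in spec)

print(translate((('r', 'b'), None)))  # ((0, 1), None)
print(translate((None, ('b', 'm'))))  # (None, (1, 2))
```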
|
|
||
|
|
||
| @xr.requires_pjrt | ||
| def mark_sharding( | ||
| t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, | ||
| partition_spec: Tuple[Union[int, str, None]]) -> XLAShardedTensor: | ||
| partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor: | ||
| """ | ||
| Annotates the tensor provided with XLA partition spec. Internally, | ||
| it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass. | ||
|
|
@@ -399,8 +431,12 @@ def mark_sharding( | |
|
|
||
| mesh (Mesh): describes the logical XLA device topology and the underlying device IDs. | ||
|
|
||
| partition_spec (Tuple[int, str, None]): A tuple of device_mesh dimension index or `None`. Each index is an int or str if the mesh axis is named. | ||
| This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). | ||
| partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or | ||
| `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str. | ||
| This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). | ||
| When a tuple is specified, the corresponding input tensor axis will be sharded along all | ||
| logical axes in the tuple. Note that the order the mesh axes are specified in the tuple | ||
| will impact the resulting sharding. | ||
| For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise. | ||
| >> input = torch.randn(8, 10) | ||
| >> mesh_shape = (4, 2) | ||
|
|
@@ -426,34 +462,38 @@ def mark_sharding( | |
| assert mesh.size() == num_devices, \ | ||
| f"{mesh.mesh_shape} is not mappable over {num_devices} devices." | ||
| partition_spec = _translate_named_partition_spec(mesh, partition_spec) | ||
| assert all((d >= 0 and d < len(mesh.mesh_shape)) for d in partition_spec if d), \ | ||
| f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." | ||
| # We only allow fully specified `partition_spec` to be applicable, as opposed | ||
| # to filling in the unspecified replicated dims. Fully specified `partition_spec` | ||
| # should be of the same rank as `t`. This is to support partial replication | ||
| # where the group assignment may vary with different input ranks. | ||
| assert len(t.shape) == len(partition_spec), \ | ||
| f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." | ||
| specs = [d for d in partition_spec if d] | ||
| flat_specs = np.hstack([d for d in partition_spec]) | ||
| specs = [d for d in flat_specs if d is not None] | ||
| assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ | ||
| f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." | ||
| assert len(specs) == len(np.unique(specs)), \ | ||
| f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." | ||
|
|
||
| tile_assignment = _get_tile_assignment(mesh, partition_spec) | ||
| if len(mesh.mesh_shape) > len(partition_spec): | ||
| if len(tile_assignment.shape) > len(partition_spec): | ||
| # Use partial replication for sharding a tensor over a higher-rank mesh | ||
| sharding_type = ShardingType.PARTIAL | ||
| else: | ||
| sharding_type = _get_sharding_type(partition_spec, num_devices) | ||
| replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} | ||
| group_assignment, replication_groups = _get_group_assignment( | ||
| sharding_type, mesh, partition_spec) | ||
| sharding_type, tile_assignment, len(partition_spec), replicate_dims) | ||
|
|
||
| if isinstance(t, XLAShardedTensor): | ||
| torch_xla._XLAC._xla_mark_sharding(t.global_tensor, tile_assignment, | ||
| torch_xla._XLAC._xla_mark_sharding(t.global_tensor, | ||
| tile_assignment.tolist(), | ||
| group_assignment, replication_groups, | ||
| int(sharding_type)) | ||
| return t | ||
| torch_xla._XLAC._xla_mark_sharding(t, tile_assignment, group_assignment, | ||
| replication_groups, int(sharding_type)) | ||
| torch_xla._XLAC._xla_mark_sharding(t, tile_assignment.tolist(), | ||
| group_assignment, replication_groups, | ||
| int(sharding_type)) | ||
| return XLAShardedTensor(t) | ||
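Putting the tuple support together from a user's perspective, a hedged end-to-end sketch is below. It assumes an SPMD-enabled environment with 8 devices and the import paths in use at the time of this PR; the mesh shape and axis names are illustrative.

```python
# Illustrative usage; assumes an SPMD-enabled runtime with 8 devices.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(range(num_devices), (2, num_devices // 2), ('data', 'model'))

t = torch.randn(16, 16).to(xm.xla_device())
# Shard dim 0 jointly over both mesh axes (order matters); replicate dim 1.
xs.mark_sharding(t, mesh, (('data', 'model'), None))
print(torch_xla._XLAC._get_xla_sharding_spec(t))
```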
|
|
||
|
|
||
|
|
@@ -491,12 +531,16 @@ class ShardingSpec: | |
|
|
||
| @xr.requires_pjrt | ||
| def __post_init__(self): | ||
| partition_spec, mesh = self.partition_spec, self.mesh | ||
| self._tile_assignment = _get_tile_assignment(mesh, partition_spec) | ||
| mesh = self.mesh | ||
| partition_spec = _translate_named_partition_spec(mesh, self.partition_spec) | ||
| tile_assignment = _get_tile_assignment(mesh, partition_spec) | ||
| self._tile_assignment = tile_assignment.tolist() | ||
| self._sharding_type = _get_sharding_type(partition_spec, | ||
| xr.global_runtime_device_count()) | ||
| replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} | ||
| self._group_assignment, self._replication_groups = _get_group_assignment( | ||
| self._sharding_type, mesh, partition_spec) | ||
| self._sharding_type, tile_assignment, len(partition_spec), | ||
| replicate_dims) | ||
|
|
||
| def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: | ||
| """ | ||
|
|
||
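Finally, since `ShardingSpec.__post_init__` now resolves named and tupled specs too, a hedged sketch of constructing such a spec is below. The positional field order `(mesh, partition_spec)` is an assumption (the dataclass definition is outside this hunk), and the 8-device mesh is illustrative; inspecting the private attributes is only to show what `__post_init__` computes.

```python
# Illustrative only: field order (mesh, partition_spec) is assumed.
import torch_xla.experimental.xla_sharding as xs

mesh = xs.Mesh(range(8), (2, 4), ('data', 'model'))  # assumes 8 devices
spec = xs.ShardingSpec(mesh, (('data', 'model'), None))

# __post_init__ translates the named, tupled spec and precomputes the tile
# assignment, sharding type, and replication groups for later application.
print(spec._sharding_type)
print(spec._group_assignment)
```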