57 changes: 57 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1,6 +1,7 @@
import copy

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
@@ -10,6 +11,7 @@
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.experimental.xla_sharding as xs
@@ -452,6 +454,61 @@ def test_no_sharding(self):
t3_expected = [9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0]
self.assertEqual(t3.tolist()[0], t3_expected)

@unittest.skipUnless(
xm.get_xla_supported_devices("TPU"),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_hybrid_mesh_shape(self):
mesh = self._get_mesh((1, self.n_devices))
hybrid_mesh = self._get_hybrid_mesh((1, self.n_devices))
# Check if shape of hybrid mesh matches mesh
self.assertEqual(mesh.get_logical_mesh().shape,
hybrid_mesh.get_logical_mesh().shape)

@patch('torch_xla.runtime.global_device_attributes')
@patch('torch_xla.core.xla_model.xla_device_hw')
def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
# mock device attributes for 2 slices of v4-8
num_slices = 2
xla_device_mock.return_value = "TPU"
device_attributes_mock.return_value = [{
'coords': [0, 0, 0],
'core_on_chip': 0,
'slice_index': 0
}, {
'core_on_chip': 0,
'coords': [1, 0, 0],
'slice_index': 0
}, {
'slice_index': 0,
'core_on_chip': 0,
'coords': [0, 1, 0]
}, {
'coords': [1, 1, 0],
'core_on_chip': 0,
'slice_index': 0
}, {
'coords': [0, 0, 0],
'slice_index': 1,
'core_on_chip': 0
}, {
'coords': [1, 0, 0],
'slice_index': 1,
'core_on_chip': 0
}, {
'coords': [0, 1, 0],
'slice_index': 1,
'core_on_chip': 0
}, {
'core_on_chip': 0,
'coords': [1, 1, 0],
'slice_index': 1
}]
hybrid_mesh = xs.HybridMesh(
ici_mesh_shape=(1, 4), dcn_mesh_shape=(num_slices, 1))
print(hybrid_mesh.get_logical_mesh())
self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
[[0, 2, 1, 3], [4, 6, 5, 7]])

Collaborator: Does this result respect the _create_device_mesh_for_nd_torus algorithm?

Collaborator Author: Yes, I have confirmed this with JAX's mesh.

Collaborator: Can you make ici_mesh_shape=(2, 2)? I think that can better show how the algorithm works.

Collaborator Author: Changed ici_mesh_shape.


if __name__ == '__main__':
test = unittest.main()
3 changes: 3 additions & 0 deletions test/spmd/test_xla_sharding_base.py
@@ -37,3 +37,6 @@ def _get_mesh(self, mesh_shape, device_ids=None):
device_ids = self.device_ids
assert len(device_ids) == self.n_devices
return xs.Mesh(device_ids, mesh_shape)

def _get_hybrid_mesh(self, ici_mesh_shape):
return xs.HybridMesh(ici_mesh_shape=ici_mesh_shape)
154 changes: 151 additions & 3 deletions torch_xla/experimental/xla_sharding.py
@@ -1,5 +1,5 @@
import os
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
import torch
import torch_xla
@@ -8,7 +8,8 @@
import torch_xla.runtime as xr

import numpy as np
from typing import Tuple, Union, List, Any
import itertools
from typing import Tuple, Union, List, Sequence, Any, Optional
from enum import IntEnum


@@ -64,12 +65,159 @@ def size(self):

def shape(self):
return OrderedDict(
(name, size) for name, size in zip(self.axis_name, self.mesh_shape))
(name, size) for name, size in zip(self.axis_names, self.mesh_shape))

def get_logical_mesh(self):
return self.device_ids.reshape(self.mesh_shape)


class HybridMesh(Mesh):
"""Creates a hybrid device mesh of devices connected with ICI and DCN networks.
The shape of the logical mesh should be ordered by increasing network intensity,
e.g. [replica, data, model], where model has the most network communication
requirements.

Args:
ici_mesh_shape: shape of the logical mesh for inner connected devices.
dcn_mesh_shape: shape of logical mesh for outer connected devices.

Example:
# This example is assuming 2 slices of v4-8.
ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
dcn_mesh_shape = (2, 1, 1)

mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape,
axis_names=('data','fsdp','tensor'))
print(mesh.shape())
>> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
"""
ici_mesh_shape: Tuple[int, ...]
dcn_mesh_shape: Tuple[int, ...]

def __init__(self,
*,
ici_mesh_shape: Tuple[int, ...],
dcn_mesh_shape: Tuple[int, ...] = None,
axis_names: Tuple[str, ...] = None):
if dcn_mesh_shape is None:
dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
assert len(ici_mesh_shape) == len(dcn_mesh_shape)
mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
self.device_attributes = xr.global_device_attributes()
if 'slice_index' in self.device_attributes[0] and np.prod(
dcn_mesh_shape) == 1:
raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
if 'slice_index' not in self.device_attributes[0] and np.prod(
dcn_mesh_shape) > 1:
raise ValueError('Invalid dcn_mesh_shape for single slice mesh')
self.ici_mesh_shape = ici_mesh_shape
self.dcn_mesh_shape = dcn_mesh_shape
if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[
0]: # multislice
mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape,
self.dcn_mesh_shape)
else:
mesh = self._create_device_mesh(self.ici_mesh_shape)
device_ids = mesh.flatten()
super().__init__(device_ids, mesh_shape, axis_names)
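
A minimal usage sketch, not part of the PR, mirroring the docstring example above; it assumes a 2-slice v4-8 multislice environment, and the expected logical mesh comes from test_hybrid_mesh earlier in this diff:

  import torch_xla.experimental.xla_sharding as xs

  # Two slices of v4-8: ICI mesh (data=1, fsdp=4), DCN mesh (data=2, fsdp=1).
  mesh = xs.HybridMesh(
      ici_mesh_shape=(1, 4),
      dcn_mesh_shape=(2, 1),
      axis_names=('data', 'fsdp'))
  print(mesh.shape())            # OrderedDict([('data', 2), ('fsdp', 4)])
  print(mesh.get_logical_mesh()) # [[0 2 1 3] [4 6 5 7]], as in test_hybrid_mesh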

def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
r"""Rearrange TPU devices in a slice into a physical mesh."""
Collaborator: Can you add:

1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
2. The following description of the function:

  r"""Rearrange TPU devices in a slice into a physical mesh.

  Args:
    devices: A list of device logical ordinals in a TPU slice.

  Returns:
    A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
      v2 and v3, global_z is instead cores_per_chip (i.e., 2).
  """

assert xm.xla_device_hw(xm.xla_device()) == 'TPU'
# coords is a 3-dims tuple representing the device in physical mesh
device_coords = [self.device_attributes[d]['coords'] for d in devices]
dims = tuple(d + 1 for d in max(device_coords))
out = np.empty(dims, dtype=object)
for coords, d in zip(device_coords, devices):
out[coords[0], coords[1], coords[2]] = d
return out
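
As a small illustration, not part of the PR, of what this helper produces for the four devices of one v4-8 slice (coords [0,0,0], [1,0,0], [0,1,0], [1,1,0], matching the mocked attributes in the test):

  import numpy as np

  # Map each device ordinal to its (x, y, z) coordinates on the slice.
  coords = {0: (0, 0, 0), 1: (1, 0, 0), 2: (0, 1, 0), 3: (1, 1, 0)}
  dims = tuple(max(c[i] for c in coords.values()) + 1 for i in range(3))  # (2, 2, 1)
  out = np.empty(dims, dtype=object)
  for d, (x, y, z) in coords.items():
    out[x, y, z] = d
  print(out[:, :, 0])  # [[0 2] [1 3]]: x indexes the first axis, y the second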

def _create_device_mesh_for_nd_torus(
Collaborator: Can you explain how this function optimizes performance according to the TPU physical topology? What's the algorithm? Is it that the inner ring has the highest performance, so we should assign the back of mesh_shape to it?

Collaborator: Spoke with Mohit offline. The rule is that the TPU topology is always 3D, and the inner 2D tensors have a faster ICI than the links that connect across them. Therefore, we should map the most speed-demanding rank, i.e. the highest rank of the mesh, to the inner 2D tensors.

Collaborator: Now that I have read more of the code, this algorithm seems quite restricted:

1. It only works for mapping a 2D or 3D logical mesh onto the 3D physical mesh.
2. For a 3D logical mesh, the logical mesh needs to be a transpose of the physical mesh.
3. For a 2D logical mesh, it tries to map a combination of the physical axes onto each dimension of the logical mesh.

After these simple rules, it then makes sure that devices that are physically close to each other are also assigned close to each other in the logical mesh. For example, assuming the logical mesh is 2D, the devices in mesh[0] are always a 2D slice of the 3D physical mesh.

If my understanding is correct, @khatwanimohit can you polish my comments and turn them into the comment of this helper?

Collaborator: You can add: This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.

self, physical_mesh: np.ndarray,
mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
# Remaining physical axes to be assigned to logical axes.
assignable_physical_mesh = list(physical_mesh.shape)
# Map each logical axis to a subset of physical axes.
assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
# Assign logical axes from highest network intensity to lowest.
# `mesh_shape` is assumed to be ordered by lowest network intensity first, so
# reverse it first.
for logical_axis_index, logical_axis_size in reversed(
list(enumerate(mesh_shape))):
for num_axes in range(3, 0, -1):
axes = itertools.combinations(assignable_physical_mesh, num_axes)
indices = itertools.combinations(
range(len(assignable_physical_mesh)), num_axes)
for c_axes, c_indices in zip(axes, indices):
if np.product(c_axes) == logical_axis_size:
assignment[logical_axis_index] = c_indices
# Zero the assigned physical axes.
assignable_physical_mesh = [
0 if i in c_indices else v
for i, v in enumerate(assignable_physical_mesh)
]
break
if assignment[logical_axis_index]:
# We already found an assignment from one candidate above.
break
else:
# If the num_axes for loop did not break, i.e. none of the candidates work,
# we land here via this for-else construct.
if logical_axis_size > 1:
raise NotImplementedError(
'Failed to find assignment for logical_axis_index'
f' {logical_axis_index} of size {logical_axis_size} with remaining'
f' assignable mesh {assignable_physical_mesh}. The size of each'
' axis in your logical mesh must be equal to the product of'
' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
)
# Flatten the assignment
transpose: List[int] = []
for x in assignment:
for y in x:
transpose.append(int(y))
return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment
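
To make the assignment concrete, here is a worked trace, not part of the PR, for one v4-8 slice as in the unit test: the physical mesh has shape (2, 2, 1) and the logical ICI mesh shape is (1, 4), so the single non-trivial logical axis absorbs all three physical axes and the flattened transpose is [0, 1, 2]:

  import numpy as np

  # Physical mesh for one v4-8 slice, as built by _get_physical_tpu_mesh:
  # device 0 -> (0,0,0), 1 -> (1,0,0), 2 -> (0,1,0), 3 -> (1,1,0).
  physical_mesh = np.empty((2, 2, 1), dtype=int)
  physical_mesh[0, 0, 0], physical_mesh[1, 0, 0] = 0, 1
  physical_mesh[0, 1, 0], physical_mesh[1, 1, 0] = 2, 3

  # The logical axis of size 4 is matched to physical axes (0, 1, 2) since
  # 2 * 2 * 1 == 4, giving assignment == [(), (0, 1, 2)] and transpose == [0, 1, 2].
  logical_mesh = physical_mesh.transpose([0, 1, 2]).reshape((1, 4))
  print(logical_mesh)  # [[0 2 1 3]] -- the first row of the expected hybrid mesh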

def _create_device_mesh(self,
Collaborator: I didn't mention this one given your logic is quite different. I suggest you undo it.

Collaborator Author: Fixed the comment.

mesh_shape: Sequence[int],
devices: Sequence[Any] = None) -> np.ndarray:
if devices is None:
devices = np.arange(xr.global_device_count())
if np.prod(mesh_shape) != len(devices):
raise ValueError(
f'Number of devices {len(devices)} must equal the product '
f'of mesh_shape {mesh_shape}')
physical_mesh = self._get_physical_tpu_mesh(devices)
device_mesh, assignment = self._create_device_mesh_for_nd_torus(
physical_mesh, mesh_shape)
return device_mesh

def _create_hybrid_device_mesh(self, ici_mesh_shape: Sequence[int],
Collaborator: Can you add:

1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
2. The following function description:

  """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.

  Args:
    ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
      by increasing network intensity, e.g. [replica, data, mdl] where mdl has
      the most network communication requirements.
    dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
      in the same order as mesh_shape.

  Returns:
    A np.ndarray of device logical ordinals with ici_mesh_shape * dcn_mesh_shape as its shape
    that can be fed into HybridMesh for hybrid parallelism.
  """

dcn_mesh_shape: Sequence[int]) -> np.ndarray:
"""Creates a device mesh based on ici and dcn mesh shape.
"""
granule_dict = defaultdict(list)
for d, dev in enumerate(self.device_attributes):
granule_dict[dev['slice_index']].append(d)
# sorts devices based on slice_index.
granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
if np.prod(dcn_mesh_shape) != len(granules):
raise ValueError(
f'Number of slices {len(granules)} must equal the product of '
f'dcn_mesh_shape {dcn_mesh_shape}')
# creates a separate internal mesh for each slice.
per_granule_meshes = [
self._create_device_mesh(ici_mesh_shape, granule)
for granule in granules
]
granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
blocks = np.vectorize(
lambda i: per_granule_meshes[i], otypes=[object])(
granule_mesh)
device_mesh = np.block(blocks.tolist())
return device_mesh
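
As a final illustration, not part of the PR, of the stitching step for the 2 x v4-8 example: np.block composes the per-slice ICI meshes along the DCN dimensions, which is exactly the logical mesh asserted in test_hybrid_mesh:

  import numpy as np

  per_granule_meshes = [
      np.array([[0, 2, 1, 3]]),  # ICI mesh of slice 0
      np.array([[4, 6, 5, 7]]),  # ICI mesh of slice 1
  ]
  granule_mesh = np.arange(2).reshape((2, 1))  # dcn_mesh_shape == (2, 1)
  blocks = np.vectorize(
      lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
  print(np.block(blocks.tolist()))  # [[0 2 1 3] [4 6 5 7]]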


class ShardingType(IntEnum):
# ShardingType enum ID maps to OpSharidng.Type (https://shorturl.at/pvAJX)
REPLICATED = 0
Expand Down