[SPMD] Hybrid Device mesh creation #5147
@@ -1,5 +1,5 @@
 import os
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 import torch
 import torch_xla
@@ -8,7 +8,8 @@
 import torch_xla.runtime as xr
 import numpy as np
-from typing import Tuple, Union, List, Any
+import itertools
+from typing import Tuple, Union, List, Sequence, Any, Optional
 from enum import IntEnum
@@ -64,12 +65,159 @@ def size(self):
   def shape(self):
     return OrderedDict(
-        (name, size) for name, size in zip(self.axis_name, self.mesh_shape))
+        (name, size) for name, size in zip(self.axis_names, self.mesh_shape))
 
   def get_logical_mesh(self):
     return self.device_ids.reshape(self.mesh_shape)
 
+class HybridMesh(Mesh):
+  """Creates a hybrid device mesh of devices connected with ICI and DCN networks.
+  The shape of the logical mesh should be ordered by increasing network intensity,
+  e.g. [replica, data, model], where 'model' has the most network communication
+  requirements.
+
+  Args:
+    ici_mesh_shape: shape of the logical mesh for inner connected devices.
+    dcn_mesh_shape: shape of the logical mesh for outer connected devices.
+
+  Example:
+    # This example assumes 2 slices of v4-8.
+    ici_mesh_shape = (1, 4, 1)  # (data, fsdp, tensor)
+    dcn_mesh_shape = (2, 1, 1)
+
+    mesh = HybridMesh(
+        ici_mesh_shape=ici_mesh_shape,
+        dcn_mesh_shape=dcn_mesh_shape,
+        axis_names=('data', 'fsdp', 'tensor'))
+    print(mesh.shape())
+    >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])
+  """
+  ici_mesh_shape: Tuple[int, ...]
+  dcn_mesh_shape: Tuple[int, ...]
+
+  def __init__(self,
+               *,
+               ici_mesh_shape: Tuple[int, ...],
+               dcn_mesh_shape: Optional[Tuple[int, ...]] = None,
+               axis_names: Optional[Tuple[str, ...]] = None):
+    if dcn_mesh_shape is None:
+      dcn_mesh_shape = tuple([1] * len(ici_mesh_shape))
+    assert len(ici_mesh_shape) == len(dcn_mesh_shape)
+    mesh_shape = tuple([x * y for x, y in zip(ici_mesh_shape, dcn_mesh_shape)])
+    self.device_attributes = xr.global_device_attributes()
+    if 'slice_index' in self.device_attributes[0] and np.prod(
+        dcn_mesh_shape) == 1:
+      raise ValueError('Provide dcn_mesh_shape to create a mesh for multislice')
+    if 'slice_index' not in self.device_attributes[0] and np.prod(
+        dcn_mesh_shape) > 1:
+      raise ValueError('Invalid dcn_mesh_shape for single slice mesh')
+    self.ici_mesh_shape = ici_mesh_shape
+    self.dcn_mesh_shape = dcn_mesh_shape
+    if np.prod(dcn_mesh_shape) > 1 and 'slice_index' in self.device_attributes[
+        0]:  # multislice
+      mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape,
+                                             self.dcn_mesh_shape)
+    else:
+      mesh = self._create_device_mesh(self.ici_mesh_shape)
+    device_ids = mesh.flatten()
+    super().__init__(device_ids, mesh_shape, axis_names)
+
+  def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
+    r"""Rearrange TPU devices in a slice into a physical mesh."""
+    assert xm.xla_device_hw(xm.xla_device()) == 'TPU'
+    # Each entry of `coords` is a 3-tuple giving a device's position in the
+    # physical (x, y, z) mesh.
+    device_coords = [self.device_attributes[d]['coords'] for d in devices]
+    dims = tuple(d + 1 for d in max(device_coords))
+    out = np.empty(dims, dtype=object)
+    for coords, d in zip(device_coords, devices):
+      out[coords[0], coords[1], coords[2]] = d
+    return out
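For intuition, here is a minimal, self-contained sketch (an editor's illustration, not code from the PR) of this rearrangement, using made-up device attributes for four devices on a 1x2x2 slice:

import numpy as np

# Hypothetical per-device attributes; the real values come from
# xr.global_device_attributes().
device_attributes = [
    {'coords': (0, 0, 0)},
    {'coords': (0, 0, 1)},
    {'coords': (0, 1, 0)},
    {'coords': (0, 1, 1)},
]
devices = list(range(len(device_attributes)))
device_coords = [device_attributes[d]['coords'] for d in devices]
dims = tuple(d + 1 for d in max(device_coords))  # (1, 2, 2)
out = np.empty(dims, dtype=object)
for coords, d in zip(device_coords, devices):
  out[coords[0], coords[1], coords[2]] = d
print(out.shape)  # (1, 2, 2): device ids arranged by physical position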
Collaborator: Can you explain how this function optimizes performance according to the TPU physical topology? What's the algorithm? Is it that the inner ring has the highest performance, so we should assign the back of the mesh_shape to it?

Collaborator: Spoke with Mohit offline. The rule is that the TPU topology is always 3D, and the inner 2D tensors have a faster ICI than the ones that connect across them. Therefore, we should group the most speed-demanding rank, i.e. the highest rank of the mesh, onto the inner 2D tensors.

Collaborator: Now that I have read more of the code, this algorithm seems quite restrictive:

After these simple rules, it then makes sure that devices that are physically close to each other are also assigned close to each other in the logical mesh. For example, assuming the logical mesh is 2D, the devices in mesh[0] are always a 2D slice of the 3D physical mesh. If my understanding is correct, @khatwanimohit can you polish my comments and turn them into the comment of this helper?

Collaborator: You can add:

+  def _create_device_mesh_for_nd_torus(
+      self, physical_mesh: np.ndarray,
+      mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
+    # Remaining physical axes to be assigned to logical axes.
+    assignable_physical_mesh = list(physical_mesh.shape)
+    # Map each logical axis to a subset of physical axes.
+    assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
+    # Assign logical axes from highest network intensity to lowest.
+    # `mesh_shape` is assumed to be ordered by lowest network intensity first,
+    # so reverse it first.
+    for logical_axis_index, logical_axis_size in reversed(
+        list(enumerate(mesh_shape))):
+      # Prefer larger subsets of physical axes: try 3 axes, then 2, then 1.
+      for num_axes in range(3, 0, -1):
+        axes = itertools.combinations(assignable_physical_mesh, num_axes)
+        indices = itertools.combinations(
+            range(len(assignable_physical_mesh)), num_axes)
+        for c_axes, c_indices in zip(axes, indices):
+          if np.prod(c_axes) == logical_axis_size:
+            assignment[logical_axis_index] = c_indices
+            # Zero the assigned physical axes.
+            assignable_physical_mesh = [
+                0 if i in c_indices else v
+                for i, v in enumerate(assignable_physical_mesh)
+            ]
+            break
+        if assignment[logical_axis_index]:
+          # We already found an assignment from one candidate above.
+          break
+      else:
+        # If the `num_axes` for loop did not break, i.e. none of the candidates
+        # worked, we reach here via this for-else construct.
+        if logical_axis_size > 1:
+          raise NotImplementedError(
+              'Failed to find assignment for logical_axis_index'
+              f' {logical_axis_index} of size {logical_axis_size} with remaining'
+              f' assignable mesh {assignable_physical_mesh}. The size of each'
+              ' axis in your logical mesh must be equal to the product of'
+              ' some subset of the physical mesh axis sizes. E.g. logical mesh (4,'
+              ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
+          )
+    # Flatten the assignment into a transpose order over physical axes.
+    transpose: List[int] = []
+    for x in assignment:
+      for y in x:
+        transpose.append(int(y))
+    return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment
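To make the assignment concrete, here is a small, self-contained sketch (an editor's illustration, not code from the PR) of the (4, 16) example from the error message above, on a 4x4x4 physical mesh:

import numpy as np

# A 4x4x4 physical torus; entries are device ids.
physical_mesh = np.arange(64).reshape(4, 4, 4)

# For mesh_shape=(4, 16), the most network-intensive logical axis (the last
# one, of size 16 = 4x4) is assigned first and takes physical axes (0, 1);
# the logical axis of size 4 then takes the remaining physical axis 2.
# Flattening the assignment [(2,), (0, 1)] yields the transpose order below.
logical_mesh = physical_mesh.transpose([2, 0, 1]).reshape(4, 16)
print(logical_mesh.shape)  # (4, 16)
# Each row of the logical mesh is a full 4x4 slice of the physical torus, so
# the communication-heavy axis stays on the fast inner ICI links.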
Collaborator: I didn't mention this one given your logic is quite different. I suggest you undo it.

Author: Fixed the comment.

+  def _create_device_mesh(self,
+                          mesh_shape: Sequence[int],
+                          devices: Optional[Sequence[Any]] = None) -> np.ndarray:
+    if devices is None:
+      devices = np.arange(xr.global_device_count())
+    if np.prod(mesh_shape) != len(devices):
+      raise ValueError(
+          f'Number of devices {len(devices)} must equal the product '
+          f'of mesh_shape {mesh_shape}')
+    physical_mesh = self._get_physical_tpu_mesh(devices)
+    device_mesh, assignment = self._create_device_mesh_for_nd_torus(
+        physical_mesh, mesh_shape)
+    return device_mesh
Collaborator: Can you add:

+  def _create_hybrid_device_mesh(self, ici_mesh_shape: Sequence[int],
+                                 dcn_mesh_shape: Sequence[int]) -> np.ndarray:
| """Creates a device mesh based on ici and dcn mesh shape. | ||
| """ | ||
| granule_dict = defaultdict(list) | ||
| for d, dev in enumerate(self.device_attributes): | ||
| granule_dict[dev['slice_index']].append(d) | ||
| # sorts devices based on slice_index. | ||
| granules = list(granule_dict[key] for key in sorted(granule_dict.keys())) | ||
| if np.prod(dcn_mesh_shape) != len(granules): | ||
| raise ValueError( | ||
| f'Number of slices {len(granules)} must equal the product of ' | ||
| f'dcn_mesh_shape {dcn_mesh_shape}') | ||
| # creates a seperate internal mesh for each slice. | ||
| per_granule_meshes = [ | ||
| self._create_device_mesh(ici_mesh_shape, granule) | ||
| for granule in granules | ||
| ] | ||
| granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) | ||
| blocks = np.vectorize( | ||
| lambda i: per_granule_meshes[i], otypes=[object])( | ||
| granule_mesh) | ||
| device_mesh = np.block(blocks.tolist()) | ||
| return device_mesh | ||
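As a rough illustration (an editor's sketch, not part of the PR) of how the per-slice meshes are merged, here is the same composition on fake device ids for 2 slices of 4 devices each, with ici_mesh_shape=(1, 4, 1) and dcn_mesh_shape=(2, 1, 1):

import numpy as np

# Hypothetical per-slice ("granule") meshes of device ids.
per_granule_meshes = [
    np.arange(0, 4).reshape(1, 4, 1),  # slice 0
    np.arange(4, 8).reshape(1, 4, 1),  # slice 1
]
dcn_mesh_shape = (2, 1, 1)
granule_mesh = np.arange(2).reshape(dcn_mesh_shape)
blocks = np.vectorize(
    lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
device_mesh = np.block(blocks.tolist())
print(device_mesh.shape)  # (2, 4, 1): elementwise product of dcn and ici shapes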
+
+
+class ShardingType(IntEnum):
+  # ShardingType enum ID maps to OpSharding.Type (https://shorturl.at/pvAJX)
+  REPLICATED = 0
Collaborator: Does this result respect the _create_device_mesh_for_nd_torus algorithm?

Author: Yes, I have confirmed this with JAX's mesh.

Collaborator: Can you make the ici_mesh_shape=(2, 2)? I think that can better show how the algorithm works.

Author: Changed ici_mesh_shape.
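For reference, a minimal sketch of the updated example (an editor's illustration, assuming a single v4-8 slice with 4 devices and the experimental module path used at the time of this PR):

import torch_xla.experimental.xla_sharding as xs

# On a single slice, dcn_mesh_shape defaults to all ones.
mesh = xs.HybridMesh(ici_mesh_shape=(2, 2), axis_names=('fsdp', 'tensor'))
print(mesh.shape())  # expected: OrderedDict([('fsdp', 2), ('tensor', 2)])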