@@ -71,6 +71,9 @@ def get_logical_mesh(self):
     return self.device_ids.reshape(self.mesh_shape)


+# HybridMesh class has been inspired by jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
+
+
 class HybridMesh(Mesh):
   """Creates a hybrid device mesh of devices connected with ICI and DCN networks.
   The shape of logical mesh should be ordered by increasing network-intensity
@@ -134,16 +137,57 @@ def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
   def _create_device_mesh_for_nd_torus(
       self, physical_mesh: np.ndarray,
       mesh_shape: Sequence[int]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
140+ """Assigns logical parallelism axes to physical axes of an N-D torus network.
141+
142+ Given logical parallelism axes with sizes in `mesh_shape` and devices in an
143+ N-dimensional torus network represented by `physical_mesh`, maps each logical
144+ axis to one or more physical axes. Prefer to map more-performance-sensitive
145+ logical axes to larger numbers of physical axes to maximize the bandwidth
146+ available to them. Also prefer to assign logical axes to multiple physical
147+ axes of the same size (e.g., a 2D square) rather than multiple physical axes
148+ of different sizes when possible.
149+
150+ Note that this routine will never split a physical axis over more than one
151+ logical axis (which would reduce total usable bandwidth but may sometimes be
152+ desired anyway). As a result, it will error out in cases where this is
153+ necessary to produce a valid mapping.
154+
+    Let's use a concrete example to explain the concepts and considerations.
+
+    Suppose the logical mesh is [data, model], for data and model parallelism
+    respectively. Also suppose that data parallelism is less performance-sensitive
+    than model parallelism. Consider a 3D TPU pod slice of shape 4x4x16,
+    represented by a physical mesh of shape (4, 4, 16).
+
+    A TPU pod slice has equal bandwidth along all axes with wraparound links, but
+    a 2D plane of size 4x4 may have faster XLA collective implementations than a
+    non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want
+    the more performance-sensitive `model` axis to be mapped to the 4x4 XY plane.
+
+    Args:
+      physical_mesh: a np.ndarray of devices in the shape of the N-D torus
+        physical topology.
+      mesh_shape: shape of the logical mesh (size of the various logical
+        parallelism axes), with axes ordered by increasing network intensity.
+
+    Returns:
+      An np.ndarray of devices in the shape of the logical mesh (mesh_shape),
+      with each logical parallelism axis mapped to one or more physical mesh
+      axes.
+      The axis assignment (a list of length num_logical_axes, whose elements
+      are tuples representing physical axis indices).
+    """
     # Remaining physical axes to be assigned to logical axes.
     assignable_physical_mesh = list(physical_mesh.shape)
     # Map each logical axis to a subset of physical axes.
     assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]
     # Assign logical axes from highest network intensity to lowest.
     # `mesh_shape` is assumed to be ordered by lowest network intensity first, so
     # reverse it first.
+    # Assign devices to a 2D or 3D logical mesh.
     for logical_axis_index, logical_axis_size in reversed(
         list(enumerate(mesh_shape))):
       for num_axes in range(3, 0, -1):
+        # Map a combination of physical axes to the logical axis.
         axes = itertools.combinations(assignable_physical_mesh, num_axes)
         indices = itertools.combinations(
             range(len(assignable_physical_mesh)), num_axes)
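
For readers following the hunk above, the sketch below restates the greedy axis-assignment idea in standalone Python, assuming only the behaviour described in the docstring. The function name `assign_logical_to_physical` is made up for illustration, the tie-breaking among candidate axis combinations is simplified (the actual method also has preferences such as square assignments), and the final transpose/reshape of the device array into the logical mesh is omitted.

# A minimal, self-contained sketch; not the code from this commit.
import itertools
import math
from typing import List, Sequence, Tuple


def assign_logical_to_physical(
    physical_shape: Sequence[int],
    mesh_shape: Sequence[int]) -> List[Tuple[int, ...]]:
  """Greedily maps each logical axis to a combination of physical torus axes.

  Logical axes are visited from most to least network-intensive (reverse
  `mesh_shape` order), and larger groups of physical axes are tried first so
  the most sensitive logical axes get the most bandwidth.
  """
  assignable_sizes = list(physical_shape)    # sizes of unassigned physical axes
  assignable_indices = list(range(len(physical_shape)))
  assignment: List[Tuple[int, ...]] = [() for _ in mesh_shape]

  for logical_index, logical_size in reversed(list(enumerate(mesh_shape))):
    for num_axes in range(3, 0, -1):         # prefer 3D, then 2D, then 1D groups
      match = next(
          (positions
           for positions in itertools.combinations(
               range(len(assignable_sizes)), num_axes)
           if math.prod(assignable_sizes[p] for p in positions) == logical_size),
          None)
      if match is not None:
        assignment[logical_index] = tuple(assignable_indices[p] for p in match)
        # A physical axis is never split across logical axes, so drop it entirely.
        for p in sorted(match, reverse=True):
          del assignable_sizes[p]
          del assignable_indices[p]
        break
    else:
      raise ValueError(
          f'Cannot map a logical axis of size {logical_size} onto the '
          f'remaining physical axes {assignable_sizes}.')
  return assignment


# The 4x4x16 example from the docstring: the more network-intensive `model`
# axis (mesh_shape[-1]) lands on the 4x4 XY plane, `data` on the length-16 axis.
print(assign_logical_to_physical((4, 4, 16), [16, 16]))  # -> [(2,), (0, 1)]

Running the sketch prints [(2,), (0, 1)], matching the mapping discussed in the docstring: logical axis 1 (`model`) maps to physical axes (0, 1) and logical axis 0 (`data`) to physical axis (2,).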