Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ def test_mark_sharding_not_ordered_partial_4d(self):
actual = (xt1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
'At least 4 devices required')
def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
ct1 = torch.randn(16, 16, device='cpu')
ct2 = torch.randn(16, 16, device='cpu')
Expand All @@ -382,15 +384,16 @@ def test_mark_sharding_not_ordered_2d_tensor_3d_mesh(self):
t1 = ct1.to(xm.xla_device())
t2 = ct2.to(xm.xla_device())
mesh = self._get_mesh((1, self.n_devices, 1))
mesh = self._get_mesh((1, self.n_devices // 2, 2))
# sharding spec here is not ordered.
xt1 = xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
if self.n_devices > 1:
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt1.global_tensor])
sharding_annotation = 'sharding={devices=[1,1,%d]%s}' % (
self.n_devices, ','.join(
[str(d) for d in mesh.get_logical_mesh().flatten()]))
self.assertIn(sharding_annotation, hlo)
actual = (xt1 + t2).cpu()
xs.mark_sharding(t1, mesh, partition_spec=(1, 0))
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(t1)
devices = f'[{self.n_devices // 2},1,2]' + ','.join(
str(x) for x in range(self.n_devices))
expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
self.assertEqual(sharding_spec, expected_spec)

actual = (t1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

def test_partial_replication_addmm(self):
Expand Down Expand Up @@ -631,22 +634,24 @@ def test_xla_sharded_hlo_dump(self):
# scalar 5 should be replicated
self.assertIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
'At least 4 devices required')
def test_2d_tensor_3d_mesh(self):
ct1 = torch.randn(16, 16, device='cpu')
ct2 = torch.randn(16, 16, device='cpu')
ct1 = torch.randn(16, 16)
ct2 = torch.randn(16, 16)
expected = ct1 + ct2

t1 = ct1.to(xm.xla_device())
t2 = ct2.to(xm.xla_device())
mesh = self._get_mesh((1, self.n_devices, 1))
t1 = xs.mark_sharding(t1, mesh, partition_spec=(1, 2))
if self.n_devices > 1:
hlo = torch_xla._XLAC._get_xla_tensors_hlo([t1.global_tensor])
# expected string in hlo %param = f32[1,4,16]{2,1,0:T(4,128)} parameter(0), sharding={devices=[1,4,1]0,2,1,3}
sharding_annotation = 'sharding={devices=[1,%d,1]%s}' % (
self.n_devices, ','.join(
[str(d) for d in mesh.get_logical_mesh().flatten()]))
self.assertIn(sharding_annotation, hlo)
mesh = self._get_mesh((1, self.n_devices // 2, 2))
xs.mark_sharding(t1, mesh, partition_spec=(0, 1))

sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(t1)
devices = f'[1,{self.n_devices // 2},2]' + ','.join(
str(x) for x in range(self.n_devices))
expected_spec = f'{{devices={devices} last_tile_dim_replicate}}'
self.assertEqual(sharding_spec, expected_spec)

actual = (t1 + t2).cpu()
self.assertTrue(torch.allclose(expected, actual))

Expand Down
6 changes: 5 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1375,13 +1375,17 @@ void InitXlaModuleBindings(py::module m) {
m.def("_xla_mark_sharding",
[](const at::Tensor& input, const py::list& tile_assignment,
const py::list& group_assignment, const py::list& replication_groups,
int sharding_type) {
int sharding_type, bool tensor_rank_less_than_mesh) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice()) << "Please set `XLA_USE_SPMD=1`";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xla::OpSharding sharding = ShardingUtil::CreateOpSharding(
tile_assignment, group_assignment, replication_groups,
ShardingUtil::ShardingType(sharding_type));
if (tensor_rank_less_than_mesh) {
// Replicate the lower-rank tensor along the last mesh dimension.
sharding.set_replicate_on_last_tile_dim(true);
}
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding,
MakeShapeWithDeviceLayout(
Expand Down
31 changes: 11 additions & 20 deletions torch_xla/experimental/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,35 +412,26 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

# check for sharding 2D tensor on a 3D mesh
original_shape = tuple(t.shape)
# number of dims to expand on tensor
tensor_expand = 0
if tensor_expand < len(mesh.get_logical_mesh().shape) - len(partition_spec):
tensor_expand = len(mesh.get_logical_mesh().shape) - len(partition_spec)
partition_spec = (None,) * tensor_expand + partition_spec
shape = (1,) * tensor_expand + (*original_shape,)
t = t.expand(shape)

tile_assignment = _get_tile_assignment(mesh, partition_spec)
tensor_rank_less_than_mesh = len(t.shape) < len(mesh.get_logical_mesh().shape)
if tensor_rank_less_than_mesh:
assert len(mesh.get_logical_mesh().shape) == len(
t.shape) + 1, 'Tensor rank must be equal to or one less than mesh rank'
tile_assignment = _get_tile_assignment(mesh, partition_spec + (None,))
else:
tile_assignment = _get_tile_assignment(mesh, partition_spec)
sharding_type = _get_sharding_type(partition_spec, num_devices)
group_assignment, replication_groups = _get_group_assignment(
sharding_type, partition_spec, tile_assignment)

def tensor_squeeze(t, tensor_expand):
if tensor_expand:
t = torch.squeeze(t, dim=tuple(range(tensor_expand)))
return t

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding(t.global_tensor, tile_assignment,
group_assignment, replication_groups,
int(sharding_type))
t = tensor_squeeze(t, tensor_expand)
int(sharding_type),
tensor_rank_less_than_mesh)
return t
torch_xla._XLAC._xla_mark_sharding(t, tile_assignment, group_assignment,
replication_groups, int(sharding_type))
t = tensor_squeeze(t, tensor_expand)
replication_groups, int(sharding_type),
tensor_rank_less_than_mesh)
return XLAShardedTensor(t)


Expand Down