Merged

Commits (39)
a70d921
OpenXLA pin update to b166243711f71b0a55daa1eda36b1dc745886784
yeounoh Feb 13, 2024
dd23814
Drop stablehlo_quant_seralization.diff gpu_hanging.diff
yeounoh Feb 13, 2024
2aeb5c3
* Add auto-sharding flag to `use_spmd`
yeounoh Oct 25, 2023
197206b
debug with local openxla
yeounoh Nov 10, 2023
2e4b511
Add debugging stubs
yeounoh Nov 14, 2023
580d527
Add ShadingUtil::ReshardParameters
yeounoh Nov 14, 2023
8ca228b
Use xla::OpSharding::UNKNOWN for implicit replication
yeounoh Nov 17, 2023
c17b963
Move to openxla head
yeounoh Nov 28, 2023
a061464
Resharding parameters should update the device data node.
yeounoh Dec 13, 2023
c6a5ad6
Group parameter resharding computations
yeounoh Dec 18, 2023
5c6ae88
Enable XLA_AUTO_SPMD_MESH for auto-sharding mesh_shape
yeounoh Dec 21, 2023
273617a
Merge master
yeounoh Jan 23, 2024
e171e1f
Add debugging probs
yeounoh Jan 23, 2024
18507a0
Avoid syncing auto-generated node sharding back to data nodes.
yeounoh Jan 23, 2024
d7d6529
OuputShardingPropagation can sync sharding from IR node after auto-sh…
yeounoh Jan 24, 2024
e2ab353
* Set default XLATensor sharding to UNKNOWN
yeounoh Feb 7, 2024
fc17109
Handle tuple shapes in resharding
yeounoh Feb 8, 2024
9312c41
Disable tuple in resharding
yeounoh Feb 9, 2024
18bea46
Remove debugging stubs
yeounoh Feb 9, 2024
4fb891a
Sync after resharding data
yeounoh Feb 13, 2024
2b3494e
* Move OpenXLA to 1fc74e9890cd7785945fa39de9a3b54659f3e792, to apply …
yeounoh Feb 14, 2024
ee0a198
Move XLA pin to 075d25e0c19e4e455ba0a2bcc432d581128e66aa
yeounoh Feb 16, 2024
ecc1760
* Skip resharding if not using SPMD.
yeounoh Feb 29, 2024
e838e32
Debug wrapping
yeounoh Mar 2, 2024
71ae2c0
Build openxla from local
yeounoh Mar 3, 2024
03af881
Reshard UNKNOWN to REPLICATED
yeounoh Mar 4, 2024
8e29449
Verify the use of UNKNOWN sharding type as auto-sharding pass does no…
yeounoh Mar 7, 2024
577c178
Unittests for simple linear training
yeounoh Mar 11, 2024
5df0e95
use UNKNOWN sharding type for device data
yeounoh Mar 11, 2024
a948778
Enable unittests for cpu
yeounoh Mar 11, 2024
edfc8ef
*Enable test_xla_auto_sharding.py
yeounoh Mar 12, 2024
1cd02b9
Introduce torch_xla.distributed.spmd.auto_policy to enable auto-sharding
yeounoh Mar 13, 2024
c42e77f
Add _xla_get_auto_sharding
yeounoh Mar 13, 2024
8efb664
Fix errors after rebasing
yeounoh Mar 13, 2024
d30a721
Test auto_policy & register SPMD to device mapper
yeounoh Mar 13, 2024
42423d3
Use DTensor API directly
yeounoh Mar 14, 2024
7899faf
Linter & refactor
yeounoh Mar 14, 2024
542ae10
Update spmd.md doc
yeounoh Mar 14, 2024
557161d
remove aten bridge change
yeounoh Mar 14, 2024
29 changes: 29 additions & 0 deletions docs/spmd.md
@@ -492,4 +492,33 @@ generated_table = visualize_sharding(sharding, use_color=False)

You can run these examples on a single TPU/GPU/CPU host and modify them to run on multiple hosts. You can also change the sharding style to `tiled`, `partial_replication`, or `replicated`; a minimal sketch of the three styles follows.
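
For instance, a minimal sketch of the three sharding styles (assuming a single host with four devices; the mesh shape and tensor size are illustrative):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()  # assumed to be 4 here
mesh = xs.Mesh(np.array(range(num_devices)), (2, 2))

t = torch.randn(8, 8).to(xm.xla_device())
# Tiled: shard dim 0 over the first mesh axis and dim 1 over the second.
xs.mark_sharding(t, mesh, (0, 1))
# Partial replication: shard dim 0 only; dim 1 stays replicated.
# xs.mark_sharding(t, mesh, (0, None))
# Replicated: no dimension is sharded.
# xs.mark_sharding(t, mesh, (None, None))
```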

### Auto-Sharding
We are introducing a new PyTorch/XLA SPMD feature called `auto-sharding` ([RFC](https://github.com/pytorch/xla/issues/6322)). This is an experimental feature in `r2.3` and `nightly` that supports `XLA:TPU` on a single TPU VM host.

PyTorch/XLA auto-sharding can be enabled in one of the following ways:
- Setting the environment variable `XLA_SPMD_AUTO=1`
- Calling the SPMD API at the beginning of your code:
```python
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```
- Calling `torch.distributed._tensor.distribute_module` with `auto_policy` and the `xla` device mesh:
```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, the model should be loaded onto the XLA device via distribute_module.
model = MyModule()  # nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
```

Optionally, one can set the following options/env-vars to control the behavior of
the XLA-based auto-sharding pass (a short sketch of setting them follows this list):
- `XLA_AUTO_USE_GROUP_SHARDING`: group resharding of the parameters. Enabled by default.
- `XLA_AUTO_SPMD_MESH`: logical mesh shape to be used for auto-sharding. For example,
`XLA_AUTO_SPMD_MESH=2,2` corresponds to a 2-by-2 mesh with 4 global devices. If unset,
a default device mesh shape of `num_devices,1` will be used.
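
A minimal sketch of configuring these knobs (the values are illustrative, and the env-vars are assumed to be read when the XLA runtime initializes, so set them before other `torch_xla` calls):

```python
import os

# Assumed to take effect only if set before the runtime is initialized.
os.environ["XLA_AUTO_SPMD_MESH"] = "2,2"         # 2-by-2 logical mesh over 4 devices
os.environ["XLA_AUTO_USE_GROUP_SHARDING"] = "1"  # group parameter resharding (default)

import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```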
2 changes: 2 additions & 0 deletions test/run_tests.sh
@@ -226,6 +226,8 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_dtensor_integration2.py"
Collaborator: do we need this on TPU CI as well, or is it ok to leave out?

Contributor Author: Ohh, I think it's ok to leave out. Want to run this sanity check on TPU!

run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
1 change: 1 addition & 0 deletions test/spmd/args_parse.py
@@ -39,6 +39,7 @@ def parse_common_options(datadir=None,
parser.add_argument('--async_closures', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--profile', action='store_true')
parser.add_argument('--auto_spmd', action='store_true')
if opts:
for name, aopts in opts:
parser.add_argument(name, **aopts)
73 changes: 38 additions & 35 deletions test/spmd/test_dtensor_integration.py
@@ -4,11 +4,13 @@
import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import DeviceMesh, Shard
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import xla_distribute_tensor, xla_distribute_module
from torch_xla.distributed.spmd import auto_policy

import unittest

@@ -19,7 +21,6 @@ class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_distribute_tensor(self):
@@ -33,8 +34,7 @@ def test_xla_distribute_tensor(self):
3,
requires_grad=requires_grad,
device=xm.xla_device())
dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh,
shard_spec)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
assert len(dist_tensor.sharding_spec) > 0
@@ -47,65 +47,68 @@ def test_xla_distribute_tensor(self):
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

def test_xla_distribute_module(self):
def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

def shard_params(mod_name, mod, mesh):
shard_spec = [Shard(0)]
# annotate fc1 and fc2
if isinstance(mod, nn.Linear):
for name, param in mod.named_parameters():
dist_param = xla_distribute_tensor(param, mesh, shard_spec)

sharded_model = xla_distribute_module(model, device_mesh, shard_params)
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

sharded_model.train()
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
for _ in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding is persisted across mark_step calls; test that the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))

def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
def test_xla_distribute_module(self):
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
def shard_params(mod_name, mod, mesh):
shard_spec = [Shard(0)]
# annotate fc1 and fc2
if isinstance(mod, nn.Linear):
for name, param in mod.named_parameters():
dist_param = distribute_tensor(param, mesh, shard_spec)

sharded_model = distribute_module(model, device_mesh, shard_params)
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
self.assertTrue(
torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

sharded_model.train()
optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
for _ in range(3):
optimizer.zero_grad()
output = model(data)
output = sharded_model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Sharding is persisted across mark_step calls; test that the sharded computation
# can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
# Should run in SPMD mode and use ExecuteReplicated.
self.assertTrue(met.counter_value("ExecuteReplicated") > 0)
self.assertTrue(met.counter_value("ExecuteComputation") is None)


if __name__ == '__main__':
Expand Down
58 changes: 58 additions & 0 deletions test/spmd/test_dtensor_integration2.py
@@ -0,0 +1,58 @@
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import auto_policy

import unittest

import test_xla_sharding_base


# This integration test passes when run independently.
class DTensorIntegrationTest2(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
super().setUpClass()

@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU device.")
def test_xla_distribute_module_auto(self):
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Use torch_xla.distributed.spmd.auto_policy to enable auto-sharding;
# Currently, the model should be loaded onto the XLA device via distribute_module.
model = self.SimpleLinear()
sharded_model = distribute_module(model, device_mesh, auto_policy)
sharded_model.train()
self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding())

optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for _ in range(5):
optimizer.zero_grad()
output = sharded_model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Should compile with auto-sharding; we expect at most 3 compilations.
cnt = met.counter_value("CompileWithAutoSharding")
self.assertTrue((cnt is not None) and (cnt <= 3))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 0 additions & 1 deletion test/spmd/test_dynamo_spmd.py
@@ -41,7 +41,6 @@ class DynamoSpmdInferenceTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_dynamo_spmd_basic(self):
1 change: 0 additions & 1 deletion test/spmd/test_spmd_graph_dump.py
@@ -18,7 +18,6 @@ class SpmdGraphDumpTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_dump_with_output_sharding(self):
2 changes: 2 additions & 0 deletions test/spmd/test_train_spmd_imagenet.py
@@ -86,6 +86,8 @@
import torch_xla.test.test_utils as test_utils
import torch_xla.distributed.spmd as xs

xr.use_spmd(auto=FLAGS.auto_spmd)

DEFAULT_KWARGS = dict(
batch_size=128,
test_set_batch_size=64,
2 changes: 2 additions & 0 deletions test/spmd/test_train_spmd_linear_model.py
@@ -36,6 +36,8 @@
FLAGS = args_parse.parse_common_options(
batch_size=128, num_epochs=1, opts=MODEL_OPTS.items())

xr.use_spmd(auto=FLAGS.auto_spmd)


class SimpleLinear(nn.Module):

74 changes: 74 additions & 0 deletions test/spmd/test_xla_auto_sharding.py
@@ -0,0 +1,74 @@
import copy

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import sys

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
from torch_xla._internal import tpu


class XlaAutoShardingTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd(auto=True)
super().setUpClass()

@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU & CPU backends.")
def test_matmul(self):
met.clear_counters()
t1 = torch.ones(64, 128)
t2 = torch.ones(128, 256)
t3 = (t1 @ t2).sum()

xt1 = t1.to(xm.xla_device())
xt2 = t2.to(xm.xla_device())
xt3 = (xt1 @ xt2).sum()
xm.mark_step()
self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1)
self.assertTrue(torch.allclose(t3, xt3.cpu()))

@unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
"Auto-sharding currently supports TPU & CPU backends.")
def test_simple_linear_training(self):
met.clear_counters()

model = self.SimpleLinear().to(xm.xla_device())
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(5):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
cnt = met.counter_value("CompileWithAutoSharding")
self.assertTrue((cnt is not None) and (cnt <= 3))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 0 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
@@ -41,7 +41,6 @@ class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def _get_sharded_model(self, mesh_shape=None):
3 changes: 0 additions & 3 deletions test/spmd/test_xla_sharding.py
@@ -28,7 +28,6 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_sharded_tensor(self):
@@ -38,8 +37,6 @@ def test_xla_sharded_tensor(self):
device=xm.xla_device())
xst1 = xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)),
partition_spec)

# TODO(244003536) add more tests for XLAShardedTensror.
self.assertTrue(isinstance(xst1, XLAShardedTensor))

def test_xla_sharded_tensor_repr(self):
8 changes: 8 additions & 0 deletions test/spmd/test_xla_sharding_base.py
@@ -1,3 +1,4 @@
import os
import unittest
import numpy as np

@@ -31,6 +32,13 @@ def forward(self, x):
def setUpClass(cls):
cls.n_devices = xr.global_runtime_device_count()
cls.device_ids = np.array(range(cls.n_devices))
xr.use_spmd()

@classmethod
def tearDownClass(cls):
del os.environ['XLA_USE_SPMD']
if 'XLA_AUTO_SPMD' in os.environ:
del os.environ['XLA_AUTO_SPMD']
Comment on lines +39 to +41
Collaborator: add a TODO here to switch to the API instead of the env var eventually?

Contributor Author: Ack, will add that under runtime.use_spmd. By the way, I think we should continue to support both if possible -- the env var is the only way to set something system-wide, which is probably why XLA relies on this mechanism.


def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'