[SPMD] auto-sharding PoC #6719
New file (58 added lines): a DTensor integration test that drives auto-sharding through `distribute_module` with `auto_policy`.

```python
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import (DeviceMesh, Shard, distribute_tensor,
                                       distribute_module)
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import auto_policy

import unittest

import test_xla_sharding_base


# This integration test passes when run independently.
class DTensorIntegrationTest2(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU device.")
  def test_xla_distribute_module_auto(self):
    device_count = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(device_count)))

    # Use torch_xla.distributed.spmd.auto_policy to enable auto-sharding;
    # currently, the model should be loaded onto the XLA device via
    # distribute_module.
    model = self.SimpleLinear()
    sharded_model = distribute_module(model, device_mesh, auto_policy)
    sharded_model.train()
    self.assertTrue(torch_xla._XLAC._xla_get_auto_sharding())

    optimizer = optim.SGD(sharded_model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(5):
      optimizer.zero_grad()
      output = sharded_model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()
    # Should compile with auto-sharding; we expect at most 3 compilations.
    cnt = met.counter_value("CompileWithAutoSharding")
    self.assertTrue((cnt is not None) and (cnt <= 3))


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
```
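For reference, the pattern this test exercises, as a minimal standalone sketch: the `nn.Linear` module and its shapes are placeholders standing in for the test's `SimpleLinear` helper, and everything else follows the test code above.

```python
import torch
from torch import nn
from torch.distributed._tensor import DeviceMesh, distribute_module

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import auto_policy

# A 1-D mesh over all addressable devices, as in the test above.
mesh = DeviceMesh("xla", list(range(xr.global_runtime_device_count())))

# distribute_module loads the module onto the XLA device; passing auto_policy
# turns on auto-sharding, so no manual sharding annotations are needed.
model = nn.Linear(128, 128)  # placeholder for the test's SimpleLinear helper
sharded_model = distribute_module(model, mesh, auto_policy)

out = sharded_model(torch.randn(32, 128).to(xm.xla_device()))
xm.mark_step()  # materialize the graph; compilation runs with auto-sharding on
```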
New file (74 added lines): unit tests that enable auto-sharding globally through the runtime API.

```python
import copy

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import sys

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
import test_xla_sharding_base

import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu
from torch_xla._internal import tpu


class XlaAutoShardingTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    xr.use_spmd(auto=True)
    super().setUpClass()

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_matmul(self):
    met.clear_counters()
    t1 = torch.ones(64, 128)
    t2 = torch.ones(128, 256)
    t3 = (t1 @ t2).sum()

    xt1 = t1.to(xm.xla_device())
    xt2 = t2.to(xm.xla_device())
    xt3 = (xt1 @ xt2).sum()
    xm.mark_step()
    self.assertEqual(met.counter_value("CompileWithAutoSharding"), 1)
    self.assertTrue(torch.allclose(t3, xt3.cpu()))

  @unittest.skipUnless(xr.device_type() in ["TPU", "CPU"],
                       "Auto-sharding currently supports TPU & CPU backends.")
  def test_simple_linear_training(self):
    met.clear_counters()

    model = self.SimpleLinear().to(xm.xla_device())
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for i in range(5):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()
    cnt = met.counter_value("CompileWithAutoSharding")
    self.assertTrue((cnt is not None) and (cnt <= 3))


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
```
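The same runtime-API path, pulled out of the unittest harness as a rough usage sketch modeled on `test_matmul` above; the compile counter is only printed here because its exact value depends on the environment.

```python
import torch

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr

# Enable SPMD with auto-sharding globally, before building any XLA tensors,
# mirroring XlaAutoShardingTest.setUpClass.
xr.use_spmd(auto=True)

t1 = torch.ones(64, 128).to(xm.xla_device())
t2 = torch.ones(128, 256).to(xm.xla_device())
result = (t1 @ t2).sum()
xm.mark_step()  # compile and run; the partitioner picks the shardings

# Counter incremented whenever a graph is compiled with auto-sharding enabled,
# as asserted in the tests above.
print(met.counter_value("CompileWithAutoSharding"))
```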
The PR also updates the shared sharding test base class (imported by both new tests as `test_xla_sharding_base`), adding an `os` import and a `tearDownClass` that cleans up the SPMD environment variables:

```diff
@@ -1,3 +1,4 @@
+import os
 import unittest
 import numpy as np
 
@@ -31,6 +32,13 @@ def forward(self, x):
   def setUpClass(cls):
     cls.n_devices = xr.global_runtime_device_count()
     cls.device_ids = np.array(range(cls.n_devices))
     xr.use_spmd()
+
+  @classmethod
+  def tearDownClass(cls):
+    del os.environ['XLA_USE_SPMD']
+    if 'XLA_AUTO_SPMD' in os.environ:
+      del os.environ['XLA_AUTO_SPMD']
+
   def _get_mesh(self, mesh_shape, device_ids=None, axis_names=None):
     assert type(mesh_shape) is tuple, 'mesh_shape must be Tuple[int]'
```

Review thread on the new tearDownClass (lines +39 to +41):

Collaborator: add a TODO here to switch to api instead of env var eventually?

Author (Contributor): Ack, will add that under …
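The exchange above concerns how SPMD and auto-sharding get switched on in the test harness. As an illustration only, here is a sketch of the two spellings; the environment-variable names come from the tearDownClass cleanup in this diff, while the '1' values and the toggle between the two paths are assumptions rather than documented behavior.

```python
import os

import torch_xla.runtime as xr

USE_ENV_VARS = False  # toggle for illustration only

if USE_ENV_VARS:
  # Environment-variable path: these are the variables the new tearDownClass
  # unsets. Setting them to '1' is assumed here; only the names appear in
  # this diff.
  os.environ['XLA_USE_SPMD'] = '1'
  os.environ['XLA_AUTO_SPMD'] = '1'
else:
  # Runtime-API path, which the reviewer suggests standardizing on and which
  # XlaAutoShardingTest.setUpClass already uses in this PR.
  xr.use_spmd(auto=True)
```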
A second review thread:

Review comment: do we need this on TPU CI as well, or is it ok to leave out?

Reply: Ohhh I think it's ok to leave out. Want to run this sanity check on TPU!