
Conversation

@yeounoh (Contributor) commented Mar 12, 2024

This implements a proof-of-concept (PoC) prototype of auto-sharding on XLA:TPU, as described in #6322.

PyTorch/XLA auto-sharding can be enabled in one of the following ways:

  • Setting the environment variable XLA_SPMD_AUTO=1
  • Calling the SPMD API at the beginning of your code (a minimal end-to-end sketch follows this list):

        import torch_xla.runtime as xr
        xr.use_spmd(auto=True)

  • Calling torch.distributed._tensor.distribute_module with auto_policy and an "xla" device mesh:

        import torch_xla.runtime as xr
        from torch.distributed._tensor import DeviceMesh, distribute_module
        from torch_xla.distributed.spmd import auto_policy

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))

        # Currently, the model should be loaded onto the xla device via distribute_module.
        model = MyModule()  # an nn.Module
        sharded_model = distribute_module(model, device_mesh, auto_policy)
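
For reference, here is a minimal sketch (not part of this PR) of what the second option might look like in a toy training step. The model, optimizer, and batch below are placeholders, and it assumes a single-host TPU SPMD setup:

    import torch
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr

    # Enable SPMD with auto-sharding before any XLA tensors are created.
    xr.use_spmd(auto=True)

    device = xm.xla_device()
    model = nn.Linear(128, 10).to(device)          # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    inputs = torch.randn(32, 128, device=device)   # placeholder batch
    targets = torch.randint(0, 10, (32,), device=device)

    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # compile and run the traced graph; auto-sharding is decided at compile time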

Some notable limitations that we will address in follow-ups:

  • XLA:GPU is not yet supported
  • TPU pods are not yet supported

cc @baoleai

@yeounoh yeounoh added the distributed SPMD and other distributed things. label Mar 12, 2024
@yeounoh yeounoh requested a review from JackCaoG March 12, 2024 00:21
@yeounoh yeounoh self-assigned this Mar 12, 2024
@yeounoh yeounoh marked this pull request as draft March 12, 2024 00:22
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 2 times, most recently from 126ceee to 4d568ef on March 12, 2024 00:25
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 2 times, most recently from 6ca8f97 to d6dc442 on March 12, 2024 00:38
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 12 times, most recently from 303b239 to d3c1d70 on March 12, 2024 07:34
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 4 times, most recently from 968bca4 to eadcae6 on March 14, 2024 18:15
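
The review exchange below refers to the new SPMD/DTensor test invocations added by this PR to the test runner script: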
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_dtensor_integration2.py"
Collaborator:

Do we need this on the TPU CI as well, or is it OK to leave it out?

Contributor Author (@yeounoh):

Ohhh, I think it's OK to leave it out. I want to run this sanity check on TPU!

Collaborator (@JackCaoG) left a comment:

Feel free to address the remaining comments in a follow-up.

@yeounoh yeounoh merged commit 370089a into master Mar 14, 2024
yeounoh added a commit that referenced this pull request Mar 14, 2024

Labels: backport_2.3, distributed (SPMD and other distributed things)
