[SPMD] Support reduce-scatter in manual sharding #7231
Conversation
Why the hell do we run SPMD tests on CPU?

You can skip your test for CPU. Most of the sharding-related tests actually pass on CPU, and the CPU test is the only CI that upstream runs against us.
```cpp
        return result_tuple;
      });
  m.def(
      "_xla_spmd_reduce_scatter",
```
So the only difference is that this one does not have a token?
That's one difference. Others are mentioned in the description.
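To make the exchange above concrete, here is a hedged sketch of how the two registrations might differ. This is not the PR's actual code: the parameter list for the SPMD variant is an assumption (mirroring `_xla_reduce_scatter` minus the token, per the discussion), and the bodies are elided to placeholder comments.

```cpp
// Sketch only — names and signatures are assumptions for illustration.
#include <cstdint>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

void InitCcOpBindings(py::module& m) {
  // Replica-based variant: threads an ordering token through the op.
  m.def("_xla_reduce_scatter",
        [](const std::string& reduce_type, py::object input, py::object token,
           double scale, int64_t scatter_dim, int64_t shard_count,
           const std::vector<std::vector<int64_t>>& groups) {
          // ... lower xla::ReduceScatter chained on the token ...
          return py::make_tuple(input, token);
        });
  // SPMD manual-sharding variant: no token; ordering comes from the
  // channel_id assigned in the lowering instead of a token chain.
  m.def("_xla_spmd_reduce_scatter",
        [](const std::string& reduce_type, py::object input, double scale,
           int64_t scatter_dim, int64_t shard_count,
           const std::vector<std::vector<int64_t>>& groups) {
          // ... lower xla::ReduceScatter with channel_id, replica_groups,
          // and use_global_device_ids ...
          return input;
        });
}
```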
The GPU test failure doesn't seem to be related.
Summary:
This PR adds experimental support for collective communication (cc) ops in manual sharding zones, starting with reduce-scatter. The key is to set channel_id, replica_groups, and use_global_device_ids in the lowering; see the sketch after the test plan.
Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_spmd_reduce_scatter
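For context, a minimal sketch of the lowering the summary describes, written against the XLA builder API. This is not the PR's actual code: the function name, argument names, header path, and channel-id assignment are assumptions for illustration. It only shows the point the summary makes — unlike the replica-based path, the SPMD manual-sharding path supplies replica_groups, a channel_id, and use_global_device_ids=true to xla::ReduceScatter.

```cpp
// Sketch only — header path and helper names may differ by XLA version.
#include <cstdint>
#include <optional>
#include <vector>

#include "xla/client/xla_builder.h"

xla::XlaOp LowerSpmdReduceScatter(
    xla::XlaOp input, const xla::XlaComputation& reduce_computation,
    int64_t scatter_dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& groups) {
  // Build replica_groups from the user-provided device groups.
  std::vector<xla::ReplicaGroup> replica_groups;
  for (const auto& group : groups) {
    xla::ReplicaGroup rg;
    for (int64_t id : group) {
      rg.add_replica_ids(id);
    }
    replica_groups.push_back(rg);
  }

  // A channel id is required for cross-partition collectives under SPMD.
  xla::ChannelHandle channel_handle;
  channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  channel_handle.set_handle(1);  // assumption: a unique, freshly assigned id

  // use_global_device_ids=true makes replica_groups address global device
  // ids, which is what a manual sharding zone needs.
  return xla::ReduceScatter(input, reduce_computation, scatter_dim,
                            shard_count, replica_groups, channel_handle,
                            /*layout=*/std::nullopt,
                            /*use_global_device_ids=*/true);
}
```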