@alanwaketan alanwaketan commented Jun 10, 2024

Summary:
This PR adds experimental support for collective communication (cc) ops in manual sharding zones, starting with reduce-scatter. The key change is to add channel_id, replica_groups, and use_global_device_ids to the lowering.

Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_spmd_reduce_scatter

@alanwaketan alanwaketan requested review from JackCaoG and jonb377 June 10, 2024 20:48
@alanwaketan alanwaketan self-assigned this Jun 10, 2024
@alanwaketan
Collaborator Author

Why the hell do we run SPMD tests on CPU?

@JackCaoG
Collaborator

You can skip your test on CPU. Most sharding-related tests actually pass on CPU, and the CPU test is the only CI that upstream runs against us.
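
One way to skip on CPU, sketched with only the standard library: the helper `pjrt_device` and the class name `ReduceScatterSkipExample` are hypothetical, and the check reads the `PJRT_DEVICE` environment variable from the test plan rather than whatever device query the real test suite uses.

```python
import os
import unittest


def pjrt_device() -> str:
    """Return the PJRT_DEVICE environment variable, upper-cased ('' if unset)."""
    return os.environ.get("PJRT_DEVICE", "").upper()


class ReduceScatterSkipExample(unittest.TestCase):

    def setUp(self):
        # Checked at run time, so the decision follows the current environment.
        if pjrt_device() == "CPU":
            self.skipTest("reduce-scatter in manual sharding is skipped on CPU")

    def test_spmd_reduce_scatter(self):
        # Placeholder body; the real test lives in test/spmd/test_xla_sharding.py.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```

Doing the check in `setUp` instead of a class-level decorator means the skip is evaluated when the test runs, not when the module is imported.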

return result_tuple;
});
m.def(
"_xla_spmd_reduce_scatter",
Collaborator

So the only difference is that this one does not have a token?

Collaborator Author

That's one difference. Others are mentioned in the description.

@alanwaketan
Collaborator Author

The GPU test failure doesn't seem to be related.
