[Pallas] Support programmatically extracting the payload #6696
I will add the CI change later.
.github/workflows/tpu_ci.yml
Outdated
run: |
  pip install fsspec
  pip install rich
  pip install -U --pre jaxlib==0.4.25.dev20240213 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Can you add these dependencies to our setup.py? Maybe a pallas extras_require since AFAIK jax will still be optional. We can update the JAX version together with libtpu-nightly when we do pin updates
Sure. I wasn't sure about making it a default for development, but I guess why not.
@will-cromar How are you going to add nightly dependencies to setup.py? I tried:
ptxla@t1v-n-94bbce5b-w-0:/workspaces/work/pytorch/xla$ git diff setup.py
diff --git a/setup.py b/setup.py
index e4fdaebee..66a1eb321 100644
--- a/setup.py
+++ b/setup.py
@@ -288,6 +288,8 @@ setup(
# importlib.metadata backport required for PJRT plugin discovery prior
# to Python 3.10
'importlib_metadata>=4.6;python_version<"3.10"',
+ 'jaxlib==0.4.25.dev20240213 @ https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html',
+ 'jax==0.4.25.dev20240213 @ https://storage.googleapis.com/jax-releases/jax_nightly_releases.html',
],
package_data={
'torch_xla': ['lib/*.so*',],
but that doesn't work...
error in torch_xla setup command: 'install_requires' must be a string or list of strings containing valid project/version requirement specifiers; Parse error at "'@ https:'": Expected string_end
Add it to extras_require so it's not required by default:
extras_require={
# On Cloud TPU VM install with:
# pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
'tpu': [f'libtpu-nightly=={_libtpu_version}'],
# On nightly, install libtpu with `pip install torch_xla[tpuvm]`
# Remove from release branches since this is not allowed by PyPI.
'tpuvm': [f'libtpu-nightly @ {_libtpu_storage_path}'],
'pallas': ['jaxlib==0.4.25.dev20240213', 'jax==0.4.25.dev20240213']
},
To be uploadable to PyPI, the package can't bake in a link to another index. The user will have to do pip install torch_xla[pallas] -f path/to/index like they do for TPU.
... unless you can just rely on a stable release of JAX that's already on PyPI. If JAX is not loading libtpu (I suspect it does not, otherwise we would crash), the JAX build date does not have to match libtpu-nightly. Maybe it's better to investigate that option before we make the feature stable so we can get this merged soon.
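The constraint described in this comment can be sketched as a small helper. This is only an illustration: the function names `build_extras_require`, `has_direct_reference`, and `is_pypi_uploadable` are hypothetical, the libtpu version is a made-up placeholder, and the check is a rough string test for the PEP 508 direct-reference form (`name @ url`), not a full requirement parser.

```python
from typing import Dict, List


def build_extras_require(libtpu_version: str,
                         jax_version: str) -> Dict[str, List[str]]:
    """Build optional dependency groups that pin versions but name no index URL."""
    return {
        "tpu": [f"libtpu-nightly=={libtpu_version}"],
        "pallas": [f"jaxlib=={jax_version}", f"jax=={jax_version}"],
    }


def has_direct_reference(requirement: str) -> bool:
    """Rough check for the PEP 508 `name @ url` form (ignores any `; marker`)."""
    return "@" in requirement.split(";", 1)[0]


def is_pypi_uploadable(extras: Dict[str, List[str]]) -> bool:
    """True when no requirement in any group points at an external URL."""
    return not any(
        has_direct_reference(req) for reqs in extras.values() for req in reqs)


# Placeholder versions: the JAX pin is the one from this thread, the libtpu
# version is made up for illustration.
extras = build_extras_require("0.1.dev20240213", "0.4.25.dev20240213")
print(is_pypi_uploadable(extras))
print(has_direct_reference("libtpu-nightly @ https://example.com/libtpu.whl"))
```

Note that the earlier attempt failed for a related reason: a PEP 508 direct reference has the shape `name @ url` and cannot be combined with a `==version` specifier, so the version pin and the index location have to be supplied separately (the pin in `extras_require`, the index through `-f` at install time).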
No, JAX needs Mosaic, which is in libtpu, to generate the payload. So we need the nightly. Thanks, Will. I'll do it this way.
test/test_operations.py
Outdated
    self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_tpu_custom_call_pallas_extract_add_payload(self):
Could the custom kernel tests be in a separate test file? IMO test_operations shouldn't depend on anything other than our basic requirements.
torchvision is imported but we don't actually use it.
Yeah, I agree that Pallas-related stuff is better put in a new test file.
I'm just too lazzzy...
lol, the problem with Pallas is that it introduces the JAX dependency. We might have to disable the test on TPU CI sometimes, and it is better not to disable the whole test_operations.py.
Moved to a new file.
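One way to keep an optional JAX dependency from affecting unrelated tests is to skip Pallas tests whenever `jax` cannot be found. The following is a minimal sketch of that idea, not the test file from this PR: the class name `PallasSmokeTest`, the helper `run_suite`, and the placeholder test body are all hypothetical.

```python
import importlib.util
import unittest

# True only when the optional `jax` package can be located in this environment.
HAS_JAX = importlib.util.find_spec("jax") is not None


@unittest.skipUnless(HAS_JAX, "Pallas tests require the optional JAX dependency.")
class PallasSmokeTest(unittest.TestCase):

    def test_jax_is_available(self):
        # Placeholder body: a real test would build a kernel and extract its payload.
        self.assertTrue(HAS_JAX)


def run_suite() -> unittest.TestResult:
    """Load and run only the Pallas test case, returning the result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(PallasSmokeTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)


result = run_suite()
print(result.wasSuccessful())
```

With the tests in their own file, disabling that file in CI (or skipping it when JAX is missing) leaves test_operations.py untouched.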
Do we need this to be merged into the 2.3 release? This introduces a JAX dependency, which might cause issues. I would prefer to leave this in nightly until things stabilize a bit.
Yeah, it needs to be in 2.3. I guess we can keep JAX as an extra dependency for those who need Pallas.
7ec4903 to
769fa66
Compare
@will-cromar I fixed the jax thing. Hopefully, this time it works.
Thanks Jack for approving!
| # def add_vectors_kernel(x_ref, y_ref, o_ref): | ||
| # x, y = x_ref[...], y_ref[...] | ||
| # o_ref[...] = x + y | ||
| payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}" |
I was going to ask if this base64 (I assume) blob means anything human-readable you could encode at test time. Then I decoded it:
$ echo TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA= | base64 --decode
ML�RMLIR18.0.0git+
K iM
3
e
+-')7GI
35
ffine_map<(d0) -> (d0)>!#%')/ 1+9;-%?A/C E1 #tpu.memory_space<vmem>K!'!�'
=
v3�
�
#!)'C
builtinfunctpuarithvectormodulereturnconstantaddiloadstore/home/jwtan/pallas/pallas_add.pyadd_vectors_kerneldimension_semanticsfunction_typescalar_prefetchscratch_operandssym_namemainvalue/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 8)], []),)), (8,), ())], []),))]add_vectors<module>/add/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 8)], []),)), (8,), ())], []),))]
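The same inspection can be done in Python instead of the shell. The sketch below uses only the first 24 characters of the `body` field as a truncated stand-in for the full blob, so the decoded bytes are just the header; the helper name `decode_payload_body` is hypothetical. It checks for the four-byte MLIR bytecode magic `ML\xefR`, which is what the decoded output starts with.

```python
import base64
import json

# Truncated stand-in: only the first 24 base64 characters of the real body.
TRUNCATED_BODY = "TUzvUgFNTElSMTguMC4wZ2l0"
payload = json.dumps({
    "custom_call_config": {"body": TRUNCATED_BODY, "needs_layout_passes": True}
})

# MLIR bytecode files begin with these four bytes.
MLIR_BYTECODE_MAGIC = b"ML\xefR"


def decode_payload_body(payload_json: str) -> bytes:
    """Parse the payload JSON and base64-decode the Mosaic module body."""
    config = json.loads(payload_json)["custom_call_config"]
    return base64.b64decode(config["body"])


decoded = decode_payload_body(payload)
print(decoded.startswith(MLIR_BYTECODE_MAGIC))
# After the magic and one version byte comes the producer string.
print(decoded[5:].decode("ascii"))
```

Since the body is serialized MLIR bytecode rather than text, only the embedded strings (dialect names, source path, kernel name) are readable after decoding, which matches the output shown above.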
You can secretly decode it in Google 3 haha...
The dynamo issue in CI is unrelated...
Summary: This pull request introduces _extract_backend_config to extract the Mosaic payload from the custom call programmatically, so that we don't need to copy and paste it. In order to run the test, JAX dependencies are added to the CI. Test Plan: python test/test_operations.py -v -k test_tpu_custom_call_pallas_extract_add_payload
0540fb8 to
18841a4
Compare
Summary:
This pull request introduces _extract_backend_config to extract the Mosaic payload from the custom call programmatically, so that we don't need to copy and paste it.
In order to run the test, JAX dependencies are added to the CI. The JAX version is the nightly from the same day as our libtpu, for the best compatibility. However, we need to figure out a way to update it automatically when we update our OpenXLA pin.
Test Plan:
python test/test_operations.py -v -k test_tpu_custom_call_pallas_extract_add_payload