Skip to content

Conversation

@alanwaketan
Copy link
Collaborator

@alanwaketan alanwaketan commented Mar 8, 2024

Summary:
This pull request introduces _extract_backend_config to extrac the Mosaic payload from the custom call programmatically such that we don't need to copy & paste.

In order to run the test, JAX dependencies are added to the CI. The JAX version is using the nightly on the same day as our libtpu for the best compatibility. However, we need to figure out a way to update that automatically when we are updating our open-xla pin.

Test Plan:
python test/test_operations.py -v -k test_tpu_custom_call_pallas_extract_add_payload

@alanwaketan alanwaketan self-assigned this Mar 8, 2024
@alanwaketan
Copy link
Collaborator Author

I will add the CI change later.

@alanwaketan alanwaketan requested a review from qihqi March 8, 2024 02:28
run: |
pip install fsspec
pip install rich
pip install -U --pre jaxlib==0.4.25.dev20240213 -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add these dependencies to our setup.py? Maybe a pallas extras_require since AFAIK jax will still be optional. We can update the JAX version together with libtpu-nightly when we do pin updates

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I wasn't quite sure to make it default for development, but I guess why not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@will-cromar How are you going to add nightly dependencies to setup.py? I tried:

ptxla@t1v-n-94bbce5b-w-0:/workspaces/work/pytorch/xla$ git diff setup.py 
diff --git a/setup.py b/setup.py
index e4fdaebee..66a1eb321 100644
--- a/setup.py
+++ b/setup.py
@@ -288,6 +288,8 @@ setup(
         # importlib.metadata backport required for PJRT plugin discovery prior
         # to Python 3.10
         'importlib_metadata>=4.6;python_version<"3.10"',
+        'jaxlib==0.4.25.dev20240213 @ https://storage.googleapis.com/jax-releases/jaxl
ib_nightly_releases.html',
+        'jax==0.4.25.dev20240213 @ https://storage.googleapis.com/jax-releases/jax_nig
htly_releases.html',
     ],
     package_data={
         'torch_xla': ['lib/*.so*',],

but that doesn't work...

error in torch_xla setup command: 'install_requires' must be a string or list of strings containing valid project/version requirement specifiers; Parse error at "'@ https:'": Expected string_end

Copy link
Collaborator

@will-cromar will-cromar Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add it to extras_require so it's not required by default:

    extras_require={
        # On Cloud TPU VM install with:
        # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
        'tpu': [f'libtpu-nightly=={_libtpu_version}'],
        # On nightly, install libtpu with `pip install torch_xla[tpuvm]`
        # Remove from release branches since this is not allowed by PyPI.
        'tpuvm': [f'libtpu-nightly @ {_libtpu_storage_path}'],
        'pallas': ['jaxlib==0.4.25.dev20240213', 'jax==0.4.25.dev20240213']
    },

To uploadable to PyPI, you can't bake in a link to another index. The user will have to do pip install torch_xla[pallas] -f path/to/index like they do for TPU.

.... unless you can just rely on a stable release of jax that's already on PyPI. If JAX is not loading libtpu (I suspect it does not, otherwise we would crash), the JAX build date does not have to match libtpu-nightly. Maybe it's better to investigate that option before we make the feature stable so we can get this merged soon.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, JAX will need Mosaic which is in the libtpu to generate the payload. So we need the nightly. Thanks, Will. Will do it in this way.

self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_extract_add_payload(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the custom kernel tests should be in a separate test file? IMO test_operations shouldn't depend on anything other than our basic requirements

  • torchvision is imported but we don't actually use it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea I agree that pallas related stuff is better to be putted in a new test file.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just too lazzzy...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol the problem of the pallas is it introcued the JAX dependency. We might have to disable the test on TPUCI sometimes, it is better not to disable the whole test_operations.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to a new file.

@JackCaoG
Copy link
Collaborator

JackCaoG commented Mar 8, 2024

Do we need this to be merged into 2.3 release? This intorduce a JAX dependency which might cause issues. I would prefer to leave this in nightly until things stabilized a bit.

@alanwaketan
Copy link
Collaborator Author

Do we need this to be merged into 2.3 release? This intorduce a JAX dependency which might cause issues. I would prefer to leave this in nightly until things stabilized a bit.

Yea, it needs to be in 2.3. I guess we can keep jax as an extra dependencies for those who needs pallas.

@alanwaketan alanwaketan force-pushed the alanwaketan/pallas_jax branch from 7ec4903 to 769fa66 Compare March 8, 2024 20:30
@alanwaketan
Copy link
Collaborator Author

@will-cromar I fixed the jax thing. Hopefully, this time it works.

@alanwaketan
Copy link
Collaborator Author

Thanks Jack for approving!

# def add_vectors_kernel(x_ref, y_ref, o_ref):
# x, y = x_ref[...], y_ref[...]
# o_ref[...] = x + y
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to ask if this base64 (I assume) blob means anything human-readable you could encode at test time. Then I decoded it:

$ echo TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA= | base64 --decode
ML�RMLIR18.0.0git+
                  	
K	iM

 3



  e









+-')7GI
       35
	ffine_map<(d0) -> (d0)>!#%')/	1+9;-%?A/C	E1	#tpu.memory_space<vmem>K!'!�'


             	=	



v3�
   �
#!)'C

	
        builtinfunctpuarithvectormodulereturnconstantaddiloadstore/home/jwtan/pallas/pallas_add.pyadd_vectors_kerneldimension_semanticsfunction_typescalar_prefetchscratch_operandssym_namemainvalue/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 8)], []),)), (8,), ())], []),))]add_vectors<module>/add/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 8)], []),)), (8,), ())], []),))]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can secretly decode it in Google 3 haha...

@alanwaketan
Copy link
Collaborator Author

The dynamo issue in CI is unrelated...

Summary:
This pull request introduces _extract_backend_config to extrac the Mosaic payload
from the custom call programmatically such that we don't need to copy & paste.

In order to run the test, JAX dependencies are added to the CI.

Test Plan:
python test/test_operations.py -v -k test_tpu_custom_call_pallas_extract_add_payload
@alanwaketan alanwaketan force-pushed the alanwaketan/pallas_jax branch from 0540fb8 to 18841a4 Compare March 8, 2024 23:23
@alanwaketan alanwaketan merged commit 3706791 into master Mar 9, 2024
@alanwaketan alanwaketan deleted the alanwaketan/pallas_jax branch March 9, 2024 18:48
@vanbasten23 vanbasten23 restored the alanwaketan/pallas_jax branch March 11, 2024 23:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants