AMP for TPUs v3 #5161
Merged
Changes from all commits (34 commits):
0df4743  tpu amp (cowanmeg)
7997da0  Add torch pin (cowanmeg)
6ac8262  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
042a631  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
401eaf6  updates (cowanmeg)
041fce2  Clean up (cowanmeg)
0aa5011  Delete test_fsdp_auto_wrap_amp.py (cowanmeg)
2157e2c  updates (cowanmeg)
e247b43  Update autocast key to pick between Cuda and Xla. Unit tests (cowanmeg)
7d215d1  lint (cowanmeg)
661759a  moving code from pt to ptxla (cowanmeg)
e339506  fixes (cowanmeg)
03da7ec  lint (cowanmeg)
e4a0daa  move autocast+test from common.sh to run_tests.sh (cowanmeg)
ff8f10b  updates (cowanmeg)
a0437a5  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
410912e  updates with pytorch (cowanmeg)
aabbd12  lint (cowanmeg)
1ce912c  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
dfe1f27  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
5275334  build autocast_mode (cowanmeg)
0652dd6  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
089f296  Merge branch 'master' of https://github.com/cowanmeg/xla into amp (cowanmeg)
e1243d8  Merge branch 'master' of https://github.com/pytorch/xla into amp (cowanmeg)
85d607a  Merge branch 'master' of https://github.com/pytorch/xla into amp (cowanmeg)
9241b0a  Merge branch 'master' of https://github.com/pytorch/xla into amp (cowanmeg)
643691c  Disable bazel remote cache (cowanmeg)
4eba81f  experiment with no new files (cowanmeg)
ffac43f  revert back (cowanmeg)
4097324  lint (cowanmeg)
c10b148  Add autocast test to tpu ci (cowanmeg)
9ce63ab  Merge branch 'amp' of https://github.com/pytorch/xla into amp (cowanmeg)
73392fc  Merge branch 'master' into amp (cowanmeg)
4ade2cc  Delete .torch_pin (cowanmeg)
(Not all changed files are rendered below: some diffs are too large to display and others failed to load.)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -1,2 +1,2 @@ |
| 1 | | from .autocast_mode import autocast, custom_fwd, custom_bwd # noqa: F401 |
| | 1 | from .autocast_mode import autocast # noqa: F401 |
| 2 | 2 | from .grad_scaler import GradScaler # noqa: F401 |
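For orientation, after this change downstream code would import the AMP helpers like so (a minimal sketch, not part of the diff; `custom_fwd`/`custom_bwd` are simply no longer re-exported from `torch_xla.amp`):

```python
# autocast and GradScaler remain the package-level exports;
# custom_fwd/custom_bwd are no longer re-exported from torch_xla.amp.
from torch_xla.amp import autocast, GradScaler
```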
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -1,5 +1,40 @@ |
| | | import torch |
| | | import torch_xla.core.xla_model as xm |
| | | from typing import Any |
| | | |
| | | autocast = torch.cuda.amp.autocast |
| | | custom_fwd = torch.cuda.amp.custom_fwd |
| | | custom_bwd = torch.cuda.amp.custom_bwd |
| | | |
| | | class autocast(torch.amp.autocast_mode.autocast): |
| | |   r""" |
| | |   See :class:`torch.autocast`. |
| | |   ``torch_xla.amp.autocast(device, args...)`` is equivalent to ``torch.autocast("xla", args...)`` for TPUs |
| | |   ``torch.autocast("cuda", args...)`` for GPUs. |
| | |   """ |
| | | |
| | |   def __init__(self, |
| | |                device, |
| | |                enabled: bool = True, |
| | |                dtype: torch.dtype = torch.bfloat16, |
| | |                cache_enabled: bool = True): |
| | |     if xm.xla_device_hw(device) == 'GPU': |
| | |       super().__init__( |
| | |           "cuda", |
| | |           enabled=enabled, |
| | |           dtype=torch.float16, |
| | |           cache_enabled=cache_enabled) |
| | |     elif xm.xla_device_hw(device) == 'TPU': |
| | |       super().__init__( |
| | |           "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) |
| | |     else: |
| | |       print( |
| | |           'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.' |
| | |       ) |
| | | |
| | |   def __enter__(self): |
| | |     return super().__enter__() |
| | | |
| | |   def __exit__(self, exc_type: Any, exc_val: Any, |
| | |                exc_tb: Any):  # type: ignore[override] |
| | |     return super().__exit__(exc_type, exc_val, exc_tb) |
| | | |
| | |   def __call__(self, func): |
| | |     return super().__call__(func) |
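Roughly, usage of the wrapper class added above would look like the following minimal sketch (the toy linear model, random batch, and hyperparameters are illustrative placeholders, not part of this PR):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
model = nn.Linear(8, 4).to(device)         # toy model, illustrative only
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

data = torch.randn(16, 8, device=device)   # random stand-in batch
target = torch.randn(16, 4, device=device)

optimizer.zero_grad()
# On TPU this wrapper resolves to torch.autocast("xla", dtype=torch.bfloat16);
# on XLA:GPU it resolves to torch.autocast("cuda", dtype=torch.float16).
with autocast(device):
  output = model(data)                          # matmul runs in the low-precision dtype
  loss = nn.functional.mse_loss(output, target) # mse_loss stays in fp32 per the cast policy
loss.backward()
xm.optimizer_step(optimizer)                    # reduces gradients (if sharded) and steps
```

Since the TPU path defaults to bfloat16, loss scaling is typically unnecessary there; the float16 path on XLA:GPU is where `GradScaler` from the same package would come in.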
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -0,0 +1,160 @@ |
| | | #include <ATen/ATen.h> |
| | | #include <ATen/NativeFunctions.h> |
| | | #include <ATen/Operators.h> |
| | | #include <ATen/autocast_mode.h> |
| | | #include <c10/core/impl/LocalDispatchKeySet.h> |
| | | #include <c10/util/intrusive_ptr.h> |
| | | #include <torch/library.h> |
| | | |
| | | namespace at { |
| | | namespace autocast { |
| | | namespace { |
| | | |
| | | #define KERNEL_XLA(OP, POLICY) KERNEL(c10::DeviceType::XLA, OP, POLICY) |
| | | |
| | | #define KERNEL_XLA2(OP, OVERLOAD, POLICY) \ |
| | |   KERNEL2(c10::DeviceType::XLA, OP, OVERLOAD, POLICY) |
| | | |
| | | #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XLA( \ |
| | |     REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, \ |
| | |     POLICY) \ |
| | |   KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(c10::DeviceType::XLA, REDISPATCH_FUNC, \ |
| | |                                         REGISTER_NAME, REGISTER_SIGNATURE, \ |
| | |                                         REDISPATCH_SIGNATURE, POLICY) |
| | | |
| | | TORCH_LIBRARY_IMPL(_, AutocastXLA, m) { |
| | |   m.fallback(torch::CppFunction::makeFallthrough()); |
| | | } |
| | | |
| | | TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) { |
| | |   // lower_precision_fp cast policy |
| | |   KERNEL_XLA(conv1d, lower_precision_fp) |
| | |   KERNEL_XLA2(conv1d, padding, lower_precision_fp) |
| | |   KERNEL_XLA(conv2d, lower_precision_fp) |
| | |   KERNEL_XLA2(conv2d, padding, lower_precision_fp) |
| | |   KERNEL_XLA(conv3d, lower_precision_fp) |
| | |   KERNEL_XLA2(conv3d, padding, lower_precision_fp) |
| | |   KERNEL_XLA(bmm, lower_precision_fp) |
| | |   KERNEL_XLA(mm, lower_precision_fp) |
| | |   KERNEL_XLA(baddbmm, lower_precision_fp) |
| | |   KERNEL_XLA(addmm, lower_precision_fp) |
| | |   KERNEL_XLA(addbmm, lower_precision_fp) |
| | |   KERNEL_XLA(linear, lower_precision_fp) |
| | |   KERNEL_XLA(matmul, lower_precision_fp) |
| | |   KERNEL_XLA(conv_tbc, lower_precision_fp) |
| | |   KERNEL_XLA(conv_transpose1d, lower_precision_fp) |
| | |   KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp) |
| | |   KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp) |
| | |   KERNEL_XLA(prelu, lower_precision_fp) |
| | |   KERNEL_XLA(relu, lower_precision_fp) |
| | |   KERNEL_XLA(max_pool2d, lower_precision_fp) |
| | | |
| | |   // fp32 cast policy |
| | |   // Commented out ops are included in the AutoCastCPU Policy, |
| | |   // but not lowered. Enable if op is lowered. |
| | |   KERNEL_XLA(batch_norm, fp32) |
| | |   KERNEL_XLA2(log_softmax, int, fp32) |
| | |   KERNEL_XLA2(log_softmax, Dimname, fp32) |
| | |   KERNEL_XLA(binary_cross_entropy, fp32) |
| | |   // KERNEL_XLA(grid_sampler, fp32) |
| | |   // KERNEL_XLA(polar, fp32) |
| | |   KERNEL_XLA(prod, fp32) |
| | |   KERNEL_XLA2(prod, dim_int, fp32) |
| | |   KERNEL_XLA2(prod, dim_Dimname, fp32) |
| | |   // KERNEL_XLA(quantile, fp32) |
| | |   // KERNEL_XLA2(quantile, scalar, fp32) |
| | |   // KERNEL_XLA(nanquantile, fp32) |
| | |   // KERNEL_XLA2(nanquantile, scalar, fp32) |
| | |   // KERNEL_XLA(stft, fp32) |
| | |   // KERNEL_XLA2(stft, center, fp32) |
| | |   KERNEL_XLA(cdist, fp32) |
| | |   // KERNEL_XLA(grid_sampler_2d, fp32) |
| | |   // KERNEL_XLA(grid_sampler_3d, fp32) |
| | |   KERNEL_XLA(trace, fp32) |
| | |   // KERNEL_XLA(view_as_complex, fp32) |
| | |   KERNEL_XLA(cholesky, fp32) |
| | |   KERNEL_XLA(cholesky_inverse, fp32) |
| | |   KERNEL_XLA(cholesky_solve, fp32) |
| | |   KERNEL_XLA(inverse, fp32) |
| | |   // KERNEL_XLA(lu_solve, fp32) |
| | |   // KERNEL_XLA(orgqr, fp32) |
| | |   // KERNEL_XLA(ormqr, fp32) |
| | |   // KERNEL_XLA(pinverse, fp32) |
| | |   KERNEL_XLA(reflection_pad1d, fp32) |
| | |   KERNEL_XLA(reflection_pad2d, fp32) |
| | |   KERNEL_XLA(replication_pad1d, fp32) |
| | |   KERNEL_XLA(replication_pad2d, fp32) |
| | |   KERNEL_XLA(replication_pad3d, fp32) |
| | |   KERNEL_XLA(mse_loss, fp32) |
| | |   KERNEL_XLA(cosine_embedding_loss, fp32) |
| | |   KERNEL_XLA(nll_loss, fp32) |
| | |   KERNEL_XLA(nll_loss2d, fp32) |
| | |   KERNEL_XLA(hinge_embedding_loss, fp32) |
| | |   // KERNEL_XLA(poisson_nll_loss, fp32) |
| | |   KERNEL_XLA(smooth_l1_loss, fp32) |
| | |   // KERNEL_XLA(cross_entropy_loss, fp32) |
| | |   KERNEL_XLA(l1_loss, fp32) |
| | |   // KERNEL_XLA(huber_loss, fp32) |
| | |   KERNEL_XLA(margin_ranking_loss, fp32) |
| | |   KERNEL_XLA(soft_margin_loss, fp32) |
| | |   KERNEL_XLA(triplet_margin_loss, fp32) |
| | |   KERNEL_XLA(multi_margin_loss, fp32) |
| | |   KERNEL_XLA2(ctc_loss, IntList, fp32) |
| | |   KERNEL_XLA2(ctc_loss, Tensor, fp32) |
| | |   KERNEL_XLA(kl_div, fp32) |
| | |   KERNEL_XLA(multilabel_margin_loss, fp32) |
| | |   KERNEL_XLA(binary_cross_entropy_with_logits, fp32) |
| | |   // KERNEL_XLA(fft_fft, fp32) |
| | |   // KERNEL_XLA(fft_ifft, fp32) |
| | |   // KERNEL_XLA(fft_fft2, fp32) |
| | |   // KERNEL_XLA(fft_ifft2, fp32) |
| | |   // KERNEL_XLA(fft_fftn, fp32) |
| | |   // KERNEL_XLA(fft_ifftn, fp32) |
| | |   // KERNEL_XLA(fft_rfft, fp32) |
| | |   // KERNEL_XLA(fft_irfft, fp32) |
| | |   // KERNEL_XLA(fft_rfft2, fp32) |
| | |   // KERNEL_XLA(fft_irfft2, fp32) |
| | |   // KERNEL_XLA(fft_rfftn, fp32) |
| | |   // KERNEL_XLA(fft_irfftn, fp32) |
| | |   // KERNEL_XLA(fft_hfft, fp32) |
| | |   // KERNEL_XLA(fft_ihfft, fp32) |
| | |   // KERNEL_XLA(linalg_cond, fp32) |
| | |   // KERNEL_XLA2(linalg_cond, p_str, fp32) |
| | |   // KERNEL_XLA(linalg_matrix_rank, fp32) |
| | |   // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32) |
| | |   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32) |
| | |   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32) |
| | |   // KERNEL_XLA(linalg_solve, fp32) |
| | |   // KERNEL_XLA(linalg_cholesky, fp32) |
| | |   // KERNEL_XLA(linalg_svdvals, fp32) |
| | |   // KERNEL_XLA(linalg_eigvals, fp32) |
| | |   // KERNEL_XLA(linalg_eigvalsh, fp32) |
| | |   // KERNEL_XLA(linalg_inv, fp32) |
| | |   // KERNEL_XLA(linalg_householder_product, fp32) |
| | |   // KERNEL_XLA(linalg_tensorinv, fp32) |
| | |   // KERNEL_XLA(linalg_tensorsolve, fp32) |
| | |   // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32) |
| | |   // KERNEL_XLA(geqrf, fp32) |
| | |   // KERNEL_XLA(_lu_with_info, fp32) |
| | |   KERNEL_XLA(qr, fp32) |
| | |   KERNEL_XLA(svd, fp32) |
| | |   KERNEL_XLA(triangular_solve, fp32) |
| | |   KERNEL_XLA(multilabel_margin_loss_forward, fp32) |
| | |   // KERNEL_XLA(linalg_qr, fp32) |
| | |   // KERNEL_XLA(linalg_cholesky_ex, fp32) |
| | |   KERNEL_XLA(linalg_svd, fp32) |
| | |   // KERNEL_XLA(linalg_eig, fp32) |
| | |   // KERNEL_XLA(linalg_eigh, fp32) |
| | |   // KERNEL_XLA(linalg_lstsq, fp32) |
| | |   KERNEL_XLA(linalg_inv_ex, fp32) |
| | | |
| | |   // promote |
| | |   KERNEL_XLA(stack, promote) |
| | |   KERNEL_XLA(cat, promote) |
| | |   KERNEL_XLA(index_copy, promote) |
| | |   KERNEL_XLA2(index_copy, dimname, promote) |
| | | } |
| | | |
| | | }  // namespace |
| | | }  // namespace autocast |
| | | }  // namespace at |
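To see what these registrations imply at the Python level, here is a hedged sketch (the dtypes in the comments are what the three cast policies should produce, assuming this runs on an XLA device with the AutocastXLA key active):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)  # float32 inputs
b = torch.randn(4, 4, device=device)

with torch.autocast("xla", dtype=torch.bfloat16):
  mm_out = torch.mm(a, b)                    # lower_precision_fp policy
  loss = torch.nn.functional.mse_loss(a, b)  # fp32 policy
  cat_out = torch.cat([a, mm_out])           # promote policy: widest input dtype wins

print(mm_out.dtype)   # expected: torch.bfloat16
print(loss.dtype)     # expected: torch.float32
print(cat_out.dtype)  # expected: torch.float32
```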
Review discussion
Is bf16 only available on TPU? Can we use bf16 on GPU?
bf16 is available on newer generations of GPUs, including the A100.
So whether bf16 is supported is hardware-dependent? XLA itself does not impose any restriction, so why always set torch.float16 in GPU mode?
I think we can relax the restriction here as a follow-up, though I don't know a good way to tell whether the current GPU handles bf16 or not. @cowanmeg FYI
Yes, we can relax the dtype for GPUs. Since this is a wrapper around PyTorch's autocast class, we can use their checks to see if the GPU handles bf16.
https://github.com/pytorch/pytorch/blob/main/torch/amp/autocast_mode.py#L269
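For reference, the follow-up being discussed might look roughly like this sketch (the helper name is hypothetical and not part of this PR; `torch.cuda.is_bf16_supported()` is an existing PyTorch capability check):

```python
import torch

def pick_gpu_autocast_dtype() -> torch.dtype:
  # Hypothetical helper: prefer bf16 on GPUs that report support for it
  # (e.g. A100), otherwise fall back to fp16 as this PR currently does.
  if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    return torch.bfloat16
  return torch.float16
```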