Use TORCH_LIBRARY*, Make Python ABI Stable#2009
Draft
crcrpar wants to merge 4 commits into
Draft
Conversation
Convert the Python-facing custom extension surface to private dispatcher-backed shims under apex._extensions. The shims load the compiled libraries with torch.ops.load_library, call torch.ops.apex registrations directly, and normalize scalar and tensor-list arguments where the dispatcher schemas are stricter than the previous pybind entry points. Move the generated Python module names out of the repository root and stop packaging top-level compatibility modules such as amp_C, fused_adam_cuda, and fused_layer_norm_cuda. Internal APEX imports and affected tests now import through apex._extensions instead, which keeps the package root clean and makes the compatibility break explicit. Replace converted C++ extension frontends with TORCH_LIBRARY dispatcher registrations and build them with py_limited_api where possible. The remaining non-stable-ABI surfaces are left out of this conversion because they still need Python object bindings or setup-time helper behavior. Preserve test-observed behavior while changing the binding layer: fp16 clip_grad falls back to PyTorch's clipping semantics, FusedDense initializes parameters like nn.Linear, and bf16 FusedDense uses the PyTorch matmul+bias path to avoid the fused cublasLt bf16 mismatch. Authored with codex gpt-5.5 xhigh Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Register the transducer joint and loss checked entry points directly with the dispatcher instead of routing through _dispatch adapters that immediately narrowed int64_t and double arguments back to int and float. Keep the dispatcher-facing C++ signatures aligned with schema binding semantics, where schema int arrives as int64_t and schema float arrives as double. The only remaining narrowing is at CUDA kernel call sites that still require float values.
Python dispatcher schemas canonicalize int and float arguments to int64_t and double before invoking C++ kernels. Several APEX custom op frontends still registered small *_dispatch helpers only to narrow those values to int or float and then call the checked C++ entry point. Register the checked entry points directly for xentropy, focal loss, Megatron scaled softmax variants, fused LAMB, and distributed optimizer frontends. Keep the explicit narrowing at the existing CUDA helper boundary so legacy kernels continue to receive the same scalar types, while the dispatcher-facing code uses the native scalar widths. Leave wrappers that adapt const dispatcher tensors to mutable internal tensor references in place; those are a separate legacy API boundary rather than the scalar precision cleanup handled here.
for more information, see https://pre-commit.ci
Contributor
There was a problem hiding this comment.
Pull request overview
This PR migrates APEX CUDA extension entry points from Python/pybind modules to dispatcher-registered torch.ops.apex operators with private Python shims under apex._extensions, moving the project toward Python limited ABI compatibility.
Changes:
- Adds shared custom-op loading helpers and many private extension shim modules.
- Converts multiple C++/CUDA frontends from
PYBIND11_MODULEtoTORCH_LIBRARY*registration. - Updates Python import sites and tests to use
apex._extensions, with a few functional adjustments for fused dense and fp16 gradient clipping.
Reviewed changes
Copilot reviewed 139 out of 139 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| apex/_custom_ops.py | Adds shared library discovery/loading and scalar/list conversion helpers. |
| apex/_extensions/init.py | Introduces private extension shim package. |
| apex/_extensions/amp_C.py | Adds dispatcher shim for AMP multi-tensor ops. |
| apex/_extensions/apex_C.py | Adds Python flatten/unflatten shim. |
| apex/_extensions/bnp.py | Adds dispatcher shim for group batch norm ops. |
| apex/_extensions/cudnn_gbn_lib.py | Adds dispatcher shim for cuDNN group batch norm. |
| apex/_extensions/distributed_adam_cuda.py | Adds dispatcher shim for distributed Adam kernels. |
| apex/_extensions/distributed_lamb_cuda.py | Adds dispatcher shim for distributed LAMB kernels. |
| apex/_extensions/fast_bottleneck.py | Adds dispatcher shim for bottleneck kernels. |
| apex/_extensions/fast_layer_norm.py | Adds dispatcher shim for fast layer norm. |
| apex/_extensions/fmhalib.py | Adds dispatcher shim for FMHA kernels. |
| apex/_extensions/focal_loss_cuda.py | Adds dispatcher shim for focal loss. |
| apex/_extensions/fused_adam_cuda.py | Adds dispatcher shim for fused Adam. |
| apex/_extensions/fused_conv_bias_relu.py | Adds dispatcher shim for fused conv/bias/ReLU. |
| apex/_extensions/fused_dense_cuda.py | Adds dispatcher shim for fused dense. |
| apex/_extensions/fused_index_mul_2d.py | Adds dispatcher shim for index-mul kernels. |
| apex/_extensions/fused_lamb_cuda.py | Adds dispatcher shim for fused LAMB. |
| apex/_extensions/fused_layer_norm_cuda.py | Adds dispatcher shim for fused layer norm. |
| apex/_extensions/fused_rotary_positional_embedding.py | Adds dispatcher shim for fused RoPE. |
| apex/_extensions/fused_weight_gradient_mlp_cuda.py | Adds dispatcher shim for fused weight-gradient MLP. |
| apex/_extensions/generic_scaled_masked_softmax_cuda.py | Adds dispatcher shim for generic scaled masked softmax. |
| apex/_extensions/group_norm_cuda.py | Adds dispatcher shim for group norm. |
| apex/_extensions/group_norm_v2_cuda.py | Adds dispatcher shim for group norm v2. |
| apex/_extensions/mlp_cuda.py | Adds dispatcher shim for MLP kernels. |
| apex/_extensions/nccl_p2p_cuda.py | Adds dispatcher shim for NCCL P2P utilities. |
| apex/_extensions/peer_memory_cuda.py | Adds dispatcher shim for peer memory utilities. |
| apex/_extensions/permutation_search_cuda.py | Adds dispatcher shim for permutation search kernels. |
| apex/_extensions/scaled_masked_softmax_cuda.py | Adds dispatcher shim for scaled masked softmax. |
| apex/_extensions/scaled_softmax_cuda.py | Adds dispatcher shim for scaled softmax. |
| apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py | Adds dispatcher shim for upper-triangular masked softmax. |
| apex/_extensions/syncbn.py | Adds dispatcher shim for sync batch norm. |
| apex/_extensions/transducer_joint_cuda.py | Adds dispatcher shim for transducer joint kernels. |
| apex/_extensions/transducer_loss_cuda.py | Adds dispatcher shim for transducer loss kernels. |
| apex/_extensions/xentropy_cuda.py | Adds dispatcher shim for xentropy kernels and version. |
| apex/contrib/bottleneck/bottleneck.py | Updates extension imports to private shims. |
| apex/contrib/bottleneck/halo_exchangers.py | Updates NCCL/peer memory imports to private shims. |
| apex/contrib/clip_grad/clip_grad.py | Updates AMP import and adds fp16 fallback path. |
| apex/contrib/conv_bias_relu/conv_bias_relu.py | Updates fused conv import to private shim. |
| apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp | Registers fused conv operators through dispatcher. |
| apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp | Registers cuDNN GBN operators through dispatcher. |
| apex/contrib/csrc/cudnn_gbn/norm_sample.cpp | Replaces torch extension dependencies with ATen APIs. |
| apex/contrib/csrc/fmha/fmha_api.cpp | Registers FMHA operators through dispatcher. |
| apex/contrib/csrc/fmha/src/fmha_fill.cu | Replaces torch extension include with ATen. |
| apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp | Registers focal loss operators through dispatcher. |
| apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp | Registers index-mul operators through dispatcher. |
| apex/contrib/csrc/layer_norm/ln_api.cpp | Registers fast layer norm operators through dispatcher. |
| apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh | Makes CUDA profiler include optional. |
| apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp | Registers NCCL P2P operators through dispatcher. |
| apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh | Replaces torch extension include with ATen/vector headers. |
| apex/contrib/csrc/optimizers/fused_adam_cuda.cpp | Registers fused Adam operators through dispatcher. |
| apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp | Registers fused LAMB operator through dispatcher. |
| apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp | Registers distributed Adam operators through dispatcher. |
| apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp | Registers distributed LAMB operators through dispatcher. |
| apex/contrib/csrc/peer_memory/peer_memory.cpp | Registers peer memory operators through dispatcher. |
| apex/contrib/csrc/peer_memory/peer_memory_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh | Replaces torch extension include with ATen/vector headers. |
| apex/contrib/csrc/transducer/transducer_joint.cpp | Registers transducer joint operators through dispatcher. |
| apex/contrib/csrc/transducer/transducer_joint_kernel.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/transducer/transducer_loss.cpp | Registers transducer loss operators through dispatcher. |
| apex/contrib/csrc/transducer/transducer_loss_kernel.cu | Replaces torch tensor APIs with ATen equivalents. |
| apex/contrib/csrc/xentropy/interface.cpp | Registers xentropy operators and version through dispatcher. |
| apex/contrib/cudnn_gbn/batch_norm.py | Updates extension imports to private shims. |
| apex/contrib/fmha/fmha.py | Updates FMHA import to private shim. |
| apex/contrib/focal_loss/init.py | Updates focal loss import to private shim. |
| apex/contrib/focal_loss/focal_loss.py | Updates focal loss import to private shim. |
| apex/contrib/group_norm/group_norm.py | Updates group norm imports to private shims. |
| apex/contrib/groupbn/init.py | Updates BNP import to private shim. |
| apex/contrib/groupbn/batch_norm.py | Updates BNP import to private shim. |
| apex/contrib/index_mul_2d/index_mul_2d.py | Updates index-mul import to private shim. |
| apex/contrib/layer_norm/layer_norm.py | Updates fast layer norm import to private shim. |
| apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py | Updates fast multihead attention import to private shim. |
| apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py | Updates fast multihead attention import to private shim. |
| apex/contrib/multihead_attn/fast_self_multihead_attn_func.py | Updates fast multihead attention import to private shim. |
| apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py | Updates fast multihead attention import to private shim. |
| apex/contrib/multihead_attn/mask_softmax_dropout_func.py | Updates fast multihead attention import to private shim. |
| apex/contrib/optimizers/distributed_fused_adam.py | Updates distributed optimizer extension imports. |
| apex/contrib/optimizers/distributed_fused_lamb.py | Updates distributed optimizer extension imports. |
| apex/contrib/optimizers/fp16_optimizer.py | Updates AMP import to private shim. |
| apex/contrib/optimizers/fused_adam.py | Updates fused Adam import to private shim. |
| apex/contrib/optimizers/fused_lamb.py | Updates AMP and fused LAMB imports to private shims. |
| apex/contrib/optimizers/fused_sgd.py | Updates AMP import to private shim. |
| apex/contrib/peer_memory/peer_halo_exchanger_1d.py | Updates peer memory import to private shim. |
| apex/contrib/peer_memory/peer_memory.py | Updates peer memory import to private shim. |
| apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py | Updates permutation search kernel imports. |
| apex/contrib/test/layer_norm/test_fast_layer_norm.py | Updates test import to private shim. |
| apex/contrib/transducer/transducer.py | Updates transducer imports to private shims. |
| apex/contrib/xentropy/softmax_xentropy.py | Updates xentropy import to private shim. |
| apex/fused_dense/fused_dense.py | Updates fused dense import, initializes parameters, and adds BF16 fallback. |
| apex/mlp/mlp.py | Updates MLP import to private shim. |
| apex/multi_tensor_apply/multi_tensor_apply.py | Updates AMP import to private shim. |
| apex/normalization/fused_layer_norm.py | Updates dynamic fused layer norm imports to private shim path. |
| apex/optimizers/fused_adagrad.py | Updates AMP import to private shim. |
| apex/optimizers/fused_adam.py | Updates AMP import to private shim. |
| apex/optimizers/fused_lamb.py | Updates AMP import to private shim. |
| apex/optimizers/fused_mixed_precision_lamb.py | Updates AMP import to private shim. |
| apex/optimizers/fused_novograd.py | Updates AMP import to private shim. |
| apex/optimizers/fused_sgd.py | Updates AMP import to private shim. |
| csrc/fused_dense.cpp | Registers fused dense operators through dispatcher. |
| csrc/fused_dense_cuda.cu | Removes torch/torch include. |
| csrc/layer_norm_cuda.cpp | Registers fused layer norm operators through dispatcher. |
| csrc/megatron/fused_rotary_positional_embedding.h | Replaces torch extension include with CUDA exception header. |
| csrc/megatron/fused_rotary_positional_embedding_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| csrc/megatron/fused_weight_gradient_dense.cpp | Registers fused weight-gradient operators through dispatcher. |
| csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu | Replaces torch extension include with ATen CUDA exceptions. |
| csrc/megatron/fused_weight_gradient_dense_cuda.cu | Replaces torch extension include with ATen CUDA exceptions. |
| csrc/megatron/generic_scaled_masked_softmax.cpp | Registers generic scaled masked softmax operators. |
| csrc/megatron/generic_scaled_masked_softmax_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| csrc/megatron/scaled_masked_softmax.cpp | Registers scaled masked softmax operators. |
| csrc/megatron/scaled_masked_softmax_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| csrc/megatron/scaled_softmax.cpp | Registers scaled softmax operators. |
| csrc/megatron/scaled_softmax_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| csrc/megatron/scaled_upper_triang_masked_softmax.cpp | Registers upper-triangular masked softmax operators. |
| csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu | Replaces torch tensor APIs with ATen equivalents. |
| csrc/mlp.cpp | Registers MLP operators through dispatcher. |
| csrc/mlp_cuda.cu | Removes torch/torch include. |
| csrc/syncbn.cpp | Registers sync batch norm operators through dispatcher. |
| tests/L0/run_optimizers/test_lamb.py | Updates test AMP import to private shim. |
| tests/distributed/synced_batchnorm/single_gpu_unit_test.py | Updates syncbn test import to private shim. |
| tests/distributed/synced_batchnorm/test_groups.py | Updates syncbn test import to private shim. |
| tests/distributed/synced_batchnorm/two_gpu_unit_test.py | Updates syncbn test import to private shim. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| except ImportError: | ||
| try: | ||
| from . import permutation_search_cuda as permutation_search_cuda_kernels | ||
| from apex._extensions import permutation_search_cuda as permutation_search_cuda_kernels |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR moves APEX custom CUDA extensions toward Python limited ABI compatibility by replacing Python C API / pybind module entry points with dispatcher-registered operators and lightweight Python shims.
The extension libraries now register their callable surface through
TORCH_LIBRARY*and are loaded from privateapex._extensionsmodules. Existing Python call sites route throughtorch.ops.apex, so the compiled extension modules no longer need to expose top-level Python modules such asamp_C.The dispatcher-facing C++ functions now use dispatcher-native scalar types where practical: Python
intmaps toint64_t, and Pythonfloatmaps todouble. Legacy CUDA helper calls still receive their existingint/floatparameters, with narrowing kept at that lower boundary.Details
apex._custom_ops.apex._extensions.PYBIND11_MODULEto dispatcher registration.*_dispatchwrappers in transducer, xentropy, focal loss, Megatron softmax, fused LAMB, and distributed optimizer frontends.const Tensor&arguments to mutable legacyTensor&internals in place.Authored by gpt-5.5