
Use TORCH_LIBRARY*, Make Python ABI Stable #2009

Draft
crcrpar wants to merge 4 commits into NVIDIA:master from crcrpar:crcrpar/stable-abi-dispatcher-extensions

@crcrpar
Collaborator

@crcrpar crcrpar commented May 27, 2026

Summary

This PR moves APEX custom CUDA extensions toward Python limited ABI compatibility by replacing Python C API / pybind module entry points with dispatcher-registered operators and lightweight Python shims.

The extension libraries now register their callable surface through TORCH_LIBRARY* and are loaded from private apex._extensions modules. Existing Python call sites route through torch.ops.apex, so the compiled extension modules no longer need to expose top-level Python modules such as amp_C.

The dispatcher-facing C++ functions now use dispatcher-native scalar types where practical: Python int maps to int64_t, and Python float maps to double. Legacy CUDA helper calls still receive their existing int / float parameters, with narrowing kept at that lower boundary.
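
To make the narrowing boundary concrete, here is a small sketch in plain Python (no PyTorch needed). The helper names `narrow_to_float32` and `narrow_to_int32` are hypothetical and do not come from this PR; they only mimic what happens when a dispatcher-width `double` or `int64_t` is handed to a legacy parameter typed `float` or `int` (assumed to be 32 bits wide, as on common platforms).

```python
import struct

# Assumption: the legacy C "int" is 32 bits wide, as on common platforms.
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def narrow_to_float32(value: float) -> float:
    """Round-trip a Python float (a C double) through a 32-bit float."""
    return struct.unpack("f", struct.pack("f", value))[0]


def narrow_to_int32(value: int) -> int:
    """Return value unchanged if it fits a 32-bit int, else raise."""
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"{value} does not fit in a 32-bit int")
    return value


if __name__ == "__main__":
    # 0.5 is exactly representable in float32; 0.1 is not, so it changes slightly.
    print(narrow_to_float32(0.5) == 0.5)
    print(narrow_to_float32(0.1) == 0.1)
```

Values such as `0.5` survive the round trip unchanged, while values such as `0.1` pick up a small rounding difference. That difference is presumably why the PR keeps the narrowing in one place, at the CUDA helper boundary.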

Details

  • Add shared custom-op loading and scalar/list helpers under apex._custom_ops.
  • Add private Python shims under apex._extensions.
  • Convert extension frontends from PYBIND11_MODULE to dispatcher registration.
  • Update Python imports to use private shims instead of top-level extension modules.
  • Remove simple scalar-only *_dispatch wrappers in transducer, xentropy, focal loss, Megatron softmax, fused LAMB, and distributed optimizer frontends.
  • Leave wrappers that adapt dispatcher const Tensor& arguments to mutable legacy Tensor& internals in place.
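
As a rough illustration of the "shared custom-op loading" idea, the sketch below finds a compiled library by name and loads it at most once. The names (`find_extension_library`, `load_once`, `CANDIDATE_SUFFIXES`) are hypothetical, and the real helpers in `apex._custom_ops` may look quite different. The loader is passed in as a callable so the sketch runs without PyTorch.

```python
from pathlib import Path
from typing import Callable, Iterable, Optional, Set

# Assumed filename suffixes; ".abi3.so" is the limited-API suffix on Linux.
CANDIDATE_SUFFIXES = (".abi3.so", ".so", ".pyd")

_loaded: Set[Path] = set()


def find_extension_library(
    name: str,
    search_dirs: Iterable[str],
    suffixes: Iterable[str] = CANDIDATE_SUFFIXES,
) -> Optional[Path]:
    """Return the first existing '<dir>/<name><suffix>' file, or None."""
    suffixes = tuple(suffixes)
    for directory in search_dirs:
        for suffix in suffixes:
            candidate = Path(directory) / f"{name}{suffix}"
            if candidate.is_file():
                return candidate
    return None


def load_once(
    name: str,
    search_dirs: Iterable[str],
    loader: Callable[[str], object],
) -> Path:
    """Locate the library and call loader(path) only on the first request."""
    path = find_extension_library(name, search_dirs)
    if path is None:
        raise ImportError(f"compiled extension {name!r} not found")
    if path not in _loaded:
        loader(str(path))
        _loaded.add(path)
    return path
```

In a real shim the `loader` argument would presumably be `torch.ops.load_library`, after which the registered operators become reachable under `torch.ops.apex`.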

Authored by gpt-5.5

crcrpar added 3 commits May 28, 2026 00:16
Convert the Python-facing custom extension surface to private dispatcher-backed shims under apex._extensions. The shims load the compiled libraries with torch.ops.load_library, call torch.ops.apex registrations directly, and normalize scalar and tensor-list arguments where the dispatcher schemas are stricter than the previous pybind entry points.
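
A minimal sketch of what "normalizing scalar and tensor-list arguments" could mean in a shim is shown below. The helper names are hypothetical and not taken from this PR. The idea is that a dispatcher schema wants exact `int`, `float`, and list-of-list inputs, whereas pybind entry points were more forgiving about tuples or int-like objects.

```python
import operator
from typing import Any, List, Sequence


def as_schema_int(value: Any) -> int:
    """Accept int-like objects, reject floats such as 2.5 (raises TypeError)."""
    return operator.index(value)


def as_schema_float(value: Any) -> float:
    """Convert ints or float-like objects to a Python float."""
    return float(value)


def as_tensor_lists(tensor_lists: Sequence[Sequence[Any]]) -> List[List[Any]]:
    """Turn nested tuples or other sequences into a plain list of lists."""
    return [list(group) for group in tensor_lists]


if __name__ == "__main__":
    # Tuples of placeholders stand in for tuples of tensors here.
    print(as_tensor_lists((("g0", "g1"), ("p0", "p1"))))
```

A shim function would apply helpers like these to its arguments before forwarding the call to the corresponding `torch.ops.apex` operator.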

Move the generated Python module names out of the repository root and stop packaging top-level compatibility modules such as amp_C, fused_adam_cuda, and fused_layer_norm_cuda. Internal APEX imports and affected tests now import through apex._extensions instead, which keeps the package root clean and makes the compatibility break explicit.

Replace converted C++ extension frontends with TORCH_LIBRARY dispatcher registrations and build them with py_limited_api where possible. The remaining non-stable-ABI surfaces are left out of this conversion because they still need Python object bindings or setup-time helper behavior.

Preserve test-observed behavior while changing the binding layer: fp16 clip_grad falls back to PyTorch's clipping semantics, FusedDense initializes parameters like nn.Linear, and bf16 FusedDense uses the PyTorch matmul+bias path to avoid the fused cublasLt bf16 mismatch.
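
For readers unfamiliar with what "initializes parameters like nn.Linear" implies, the following plain-Python sketch reproduces the default bounds as I understand them from PyTorch's `nn.Linear.reset_parameters` (Kaiming-uniform with `a = sqrt(5)` for the weight, and a uniform bias with bound `1 / sqrt(fan_in)`). The function names are hypothetical, nested lists stand in for tensors, and whether FusedDense in this PR follows exactly these steps is an assumption.

```python
import math
import random
from typing import List, Tuple


def linear_init_bound(in_features: int) -> float:
    """Uniform bound implied by Kaiming-uniform init with a = sqrt(5)."""
    gain = math.sqrt(2.0 / (1.0 + 5.0))  # leaky_relu gain for slope sqrt(5)
    std = gain / math.sqrt(in_features)
    return math.sqrt(3.0) * std  # simplifies to 1 / sqrt(in_features)


def init_linear_like(
    in_features: int, out_features: int, seed: int = 0
) -> Tuple[List[List[float]], List[float]]:
    """Draw weight (out x in) and bias (out) uniformly within the bound."""
    rng = random.Random(seed)
    bound = linear_init_bound(in_features)
    weight = [
        [rng.uniform(-bound, bound) for _ in range(in_features)]
        for _ in range(out_features)
    ]
    bias = [rng.uniform(-bound, bound) for _ in range(out_features)]
    return weight, bias
```

With `in_features = 16`, both the weight and the bias are drawn from the interval `[-0.25, 0.25]`.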

Authored with codex gpt-5.5 xhigh

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

Register the transducer joint and loss checked entry points directly with the dispatcher instead of routing through _dispatch adapters that immediately narrowed int64_t and double arguments back to int and float.

Keep the dispatcher-facing C++ signatures aligned with schema binding semantics, where schema int arrives as int64_t and schema float arrives as double. The only remaining narrowing is at CUDA kernel call sites that still require float values.

Python dispatcher schemas canonicalize int and float arguments to int64_t and double before invoking C++ kernels. Several APEX custom op frontends still registered small *_dispatch helpers only to narrow those values to int or float and then call the checked C++ entry point.

Register the checked entry points directly for xentropy, focal loss, Megatron scaled softmax variants, fused LAMB, and distributed optimizer frontends. Keep the explicit narrowing at the existing CUDA helper boundary so legacy kernels continue to receive the same scalar types, while the dispatcher-facing code uses the native scalar widths.

Leave wrappers that adapt const dispatcher tensors to mutable internal tensor references in place; those are a separate legacy API boundary rather than the scalar precision cleanup handled here.

@crcrpar crcrpar marked this pull request as draft May 27, 2026 16:48
Contributor

Copilot AI left a comment


Pull request overview

This PR migrates APEX CUDA extension entry points from Python/pybind modules to dispatcher-registered torch.ops.apex operators with private Python shims under apex._extensions, moving the project toward Python limited ABI compatibility.

Changes:

  • Adds shared custom-op loading helpers and many private extension shim modules.
  • Converts multiple C++/CUDA frontends from PYBIND11_MODULE to TORCH_LIBRARY* registration.
  • Updates Python import sites and tests to use apex._extensions, with a few functional adjustments for fused dense and fp16 gradient clipping.

Reviewed changes

Copilot reviewed 139 out of 139 changed files in this pull request and generated 1 comment.

File Description
apex/_custom_ops.py Adds shared library discovery/loading and scalar/list conversion helpers.
apex/_extensions/__init__.py Introduces private extension shim package.
apex/_extensions/amp_C.py Adds dispatcher shim for AMP multi-tensor ops.
apex/_extensions/apex_C.py Adds Python flatten/unflatten shim.
apex/_extensions/bnp.py Adds dispatcher shim for group batch norm ops.
apex/_extensions/cudnn_gbn_lib.py Adds dispatcher shim for cuDNN group batch norm.
apex/_extensions/distributed_adam_cuda.py Adds dispatcher shim for distributed Adam kernels.
apex/_extensions/distributed_lamb_cuda.py Adds dispatcher shim for distributed LAMB kernels.
apex/_extensions/fast_bottleneck.py Adds dispatcher shim for bottleneck kernels.
apex/_extensions/fast_layer_norm.py Adds dispatcher shim for fast layer norm.
apex/_extensions/fmhalib.py Adds dispatcher shim for FMHA kernels.
apex/_extensions/focal_loss_cuda.py Adds dispatcher shim for focal loss.
apex/_extensions/fused_adam_cuda.py Adds dispatcher shim for fused Adam.
apex/_extensions/fused_conv_bias_relu.py Adds dispatcher shim for fused conv/bias/ReLU.
apex/_extensions/fused_dense_cuda.py Adds dispatcher shim for fused dense.
apex/_extensions/fused_index_mul_2d.py Adds dispatcher shim for index-mul kernels.
apex/_extensions/fused_lamb_cuda.py Adds dispatcher shim for fused LAMB.
apex/_extensions/fused_layer_norm_cuda.py Adds dispatcher shim for fused layer norm.
apex/_extensions/fused_rotary_positional_embedding.py Adds dispatcher shim for fused RoPE.
apex/_extensions/fused_weight_gradient_mlp_cuda.py Adds dispatcher shim for fused weight-gradient MLP.
apex/_extensions/generic_scaled_masked_softmax_cuda.py Adds dispatcher shim for generic scaled masked softmax.
apex/_extensions/group_norm_cuda.py Adds dispatcher shim for group norm.
apex/_extensions/group_norm_v2_cuda.py Adds dispatcher shim for group norm v2.
apex/_extensions/mlp_cuda.py Adds dispatcher shim for MLP kernels.
apex/_extensions/nccl_p2p_cuda.py Adds dispatcher shim for NCCL P2P utilities.
apex/_extensions/peer_memory_cuda.py Adds dispatcher shim for peer memory utilities.
apex/_extensions/permutation_search_cuda.py Adds dispatcher shim for permutation search kernels.
apex/_extensions/scaled_masked_softmax_cuda.py Adds dispatcher shim for scaled masked softmax.
apex/_extensions/scaled_softmax_cuda.py Adds dispatcher shim for scaled softmax.
apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py Adds dispatcher shim for upper-triangular masked softmax.
apex/_extensions/syncbn.py Adds dispatcher shim for sync batch norm.
apex/_extensions/transducer_joint_cuda.py Adds dispatcher shim for transducer joint kernels.
apex/_extensions/transducer_loss_cuda.py Adds dispatcher shim for transducer loss kernels.
apex/_extensions/xentropy_cuda.py Adds dispatcher shim for xentropy kernels and version.
apex/contrib/bottleneck/bottleneck.py Updates extension imports to private shims.
apex/contrib/bottleneck/halo_exchangers.py Updates NCCL/peer memory imports to private shims.
apex/contrib/clip_grad/clip_grad.py Updates AMP import and adds fp16 fallback path.
apex/contrib/conv_bias_relu/conv_bias_relu.py Updates fused conv import to private shim.
apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp Registers fused conv operators through dispatcher.
apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp Registers cuDNN GBN operators through dispatcher.
apex/contrib/csrc/cudnn_gbn/norm_sample.cpp Replaces torch extension dependencies with ATen APIs.
apex/contrib/csrc/fmha/fmha_api.cpp Registers FMHA operators through dispatcher.
apex/contrib/csrc/fmha/src/fmha_fill.cu Replaces torch extension include with ATen.
apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp Registers focal loss operators through dispatcher.
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp Registers index-mul operators through dispatcher.
apex/contrib/csrc/layer_norm/ln_api.cpp Registers fast layer norm operators through dispatcher.
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh Makes CUDA profiler include optional.
apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp Registers NCCL P2P operators through dispatcher.
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh Replaces torch extension include with ATen/vector headers.
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp Registers fused Adam operators through dispatcher.
apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp Registers fused LAMB operator through dispatcher.
apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp Registers distributed Adam operators through dispatcher.
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp Registers distributed LAMB operators through dispatcher.
apex/contrib/csrc/peer_memory/peer_memory.cpp Registers peer memory operators through dispatcher.
apex/contrib/csrc/peer_memory/peer_memory_cuda.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh Replaces torch extension include with ATen/vector headers.
apex/contrib/csrc/transducer/transducer_joint.cpp Registers transducer joint operators through dispatcher.
apex/contrib/csrc/transducer/transducer_joint_kernel.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/transducer/transducer_loss.cpp Registers transducer loss operators through dispatcher.
apex/contrib/csrc/transducer/transducer_loss_kernel.cu Replaces torch tensor APIs with ATen equivalents.
apex/contrib/csrc/xentropy/interface.cpp Registers xentropy operators and version through dispatcher.
apex/contrib/cudnn_gbn/batch_norm.py Updates extension imports to private shims.
apex/contrib/fmha/fmha.py Updates FMHA import to private shim.
apex/contrib/focal_loss/__init__.py Updates focal loss import to private shim.
apex/contrib/focal_loss/focal_loss.py Updates focal loss import to private shim.
apex/contrib/group_norm/group_norm.py Updates group norm imports to private shims.
apex/contrib/groupbn/__init__.py Updates BNP import to private shim.
apex/contrib/groupbn/batch_norm.py Updates BNP import to private shim.
apex/contrib/index_mul_2d/index_mul_2d.py Updates index-mul import to private shim.
apex/contrib/layer_norm/layer_norm.py Updates fast layer norm import to private shim.
apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py Updates fast multihead attention import to private shim.
apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py Updates fast multihead attention import to private shim.
apex/contrib/multihead_attn/fast_self_multihead_attn_func.py Updates fast multihead attention import to private shim.
apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py Updates fast multihead attention import to private shim.
apex/contrib/multihead_attn/mask_softmax_dropout_func.py Updates fast multihead attention import to private shim.
apex/contrib/optimizers/distributed_fused_adam.py Updates distributed optimizer extension imports.
apex/contrib/optimizers/distributed_fused_lamb.py Updates distributed optimizer extension imports.
apex/contrib/optimizers/fp16_optimizer.py Updates AMP import to private shim.
apex/contrib/optimizers/fused_adam.py Updates fused Adam import to private shim.
apex/contrib/optimizers/fused_lamb.py Updates AMP and fused LAMB imports to private shims.
apex/contrib/optimizers/fused_sgd.py Updates AMP import to private shim.
apex/contrib/peer_memory/peer_halo_exchanger_1d.py Updates peer memory import to private shim.
apex/contrib/peer_memory/peer_memory.py Updates peer memory import to private shim.
apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py Updates permutation search kernel imports.
apex/contrib/test/layer_norm/test_fast_layer_norm.py Updates test import to private shim.
apex/contrib/transducer/transducer.py Updates transducer imports to private shims.
apex/contrib/xentropy/softmax_xentropy.py Updates xentropy import to private shim.
apex/fused_dense/fused_dense.py Updates fused dense import, initializes parameters, and adds BF16 fallback.
apex/mlp/mlp.py Updates MLP import to private shim.
apex/multi_tensor_apply/multi_tensor_apply.py Updates AMP import to private shim.
apex/normalization/fused_layer_norm.py Updates dynamic fused layer norm imports to private shim path.
apex/optimizers/fused_adagrad.py Updates AMP import to private shim.
apex/optimizers/fused_adam.py Updates AMP import to private shim.
apex/optimizers/fused_lamb.py Updates AMP import to private shim.
apex/optimizers/fused_mixed_precision_lamb.py Updates AMP import to private shim.
apex/optimizers/fused_novograd.py Updates AMP import to private shim.
apex/optimizers/fused_sgd.py Updates AMP import to private shim.
csrc/fused_dense.cpp Registers fused dense operators through dispatcher.
csrc/fused_dense_cuda.cu Removes torch/torch include.
csrc/layer_norm_cuda.cpp Registers fused layer norm operators through dispatcher.
csrc/megatron/fused_rotary_positional_embedding.h Replaces torch extension include with CUDA exception header.
csrc/megatron/fused_rotary_positional_embedding_cuda.cu Replaces torch tensor APIs with ATen equivalents.
csrc/megatron/fused_weight_gradient_dense.cpp Registers fused weight-gradient operators through dispatcher.
csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu Replaces torch extension include with ATen CUDA exceptions.
csrc/megatron/fused_weight_gradient_dense_cuda.cu Replaces torch extension include with ATen CUDA exceptions.
csrc/megatron/generic_scaled_masked_softmax.cpp Registers generic scaled masked softmax operators.
csrc/megatron/generic_scaled_masked_softmax_cuda.cu Replaces torch tensor APIs with ATen equivalents.
csrc/megatron/scaled_masked_softmax.cpp Registers scaled masked softmax operators.
csrc/megatron/scaled_masked_softmax_cuda.cu Replaces torch tensor APIs with ATen equivalents.
csrc/megatron/scaled_softmax.cpp Registers scaled softmax operators.
csrc/megatron/scaled_softmax_cuda.cu Replaces torch tensor APIs with ATen equivalents.
csrc/megatron/scaled_upper_triang_masked_softmax.cpp Registers upper-triangular masked softmax operators.
csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu Replaces torch tensor APIs with ATen equivalents.
csrc/mlp.cpp Registers MLP operators through dispatcher.
csrc/mlp_cuda.cu Removes torch/torch include.
csrc/syncbn.cpp Registers sync batch norm operators through dispatcher.
tests/L0/run_optimizers/test_lamb.py Updates test AMP import to private shim.
tests/distributed/synced_batchnorm/single_gpu_unit_test.py Updates syncbn test import to private shim.
tests/distributed/synced_batchnorm/test_groups.py Updates syncbn test import to private shim.
tests/distributed/synced_batchnorm/two_gpu_unit_test.py Updates syncbn test import to private shim.


 except ImportError:
     try:
-        from . import permutation_search_cuda as permutation_search_cuda_kernels
+        from apex._extensions import permutation_search_cuda as permutation_search_cuda_kernels
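
The import lines above sit inside a `try` / `except ImportError` fallback chain. A generic version of that pattern, written as a hedged sketch with a hypothetical helper name (`import_first`) that is not part of APEX, could look like this:

```python
import importlib
from types import ModuleType
from typing import Iterable


def import_first(candidates: Iterable[str]) -> ModuleType:
    """Import and return the first module name that can be imported."""
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:
            # Remember why this candidate failed, then try the next one.
            errors.append(f"{name}: {exc}")
    raise ImportError(
        "none of the candidates could be imported: " + "; ".join(errors)
    )
```

Under this sketch, a call site could list `"apex._extensions.permutation_search_cuda"` as a candidate and report every failed attempt in a single error message, rather than nesting `try` blocks.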

2 participants