
Fix triton cross-entropy for large vocab sizes, support tensor-parallel #466

Draft

jlamypoirier wants to merge 5 commits into jlp_entropy_loss_tweaks from jlp_triton_loss

Conversation

@jlamypoirier (Collaborator) commented Jan 31, 2026

✨ Description

Add looped and tensor-parallel (TP) implementations of the Triton cross-entropy loss. It turns out the 64K vocab-size limitation is gone, but larger vocab sizes make the single kernel much slower, so the looped version is still better (above 32K, actually).
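For reference, a minimal PyTorch sketch of the "looped" idea: accumulate a streaming log-sum-exp over vocabulary chunks instead of materializing one huge softmax. This is an illustration only, not the Triton kernels in this PR; the function name and `chunk_size` are placeholders.

```python
import torch

def looped_cross_entropy(logits: torch.Tensor, targets: torch.Tensor,
                         chunk_size: int = 32768) -> torch.Tensor:
    # Forward pass only. logits: (num_tokens, vocab_size), targets: (num_tokens,).
    num_tokens, vocab_size = logits.shape
    max_ = torch.full((num_tokens,), float("-inf"), device=logits.device)
    sum_exp = torch.zeros(num_tokens, device=logits.device, dtype=torch.float32)
    target_logit = torch.empty(num_tokens, device=logits.device, dtype=torch.float32)

    for start in range(0, vocab_size, chunk_size):
        chunk = logits[:, start:start + chunk_size].float()
        # Online log-sum-exp: rescale the running sum whenever the max grows.
        chunk_max = chunk.max(dim=1).values
        new_max = torch.maximum(max_, chunk_max)
        sum_exp = sum_exp * torch.exp(max_ - new_max) \
            + torch.exp(chunk - new_max.unsqueeze(1)).sum(dim=1)
        max_ = new_max
        # Pick up the target logit where the target index falls inside this chunk.
        in_chunk = (targets >= start) & (targets < start + chunk.shape[1])
        target_logit[in_chunk] = chunk[in_chunk, targets[in_chunk] - start]

    # loss_i = logsumexp(logits_i) - logits_i[target_i]
    return (max_ + sum_exp.log() - target_logit).mean()
```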

Test benchmark (8K tokens, CUDA time + estimated memory usage):

# Single GPU, vocab 10K
fused 0.348 ms 492.078 MB
triton 0.169 ms 163.873 MB

# Single GPU, vocab 100K
fused 4.241 ms 4915.233 MB
triton 1.709 ms 1638.433 MB

# 2 GPUs, vocab 10K
fused 1.108 ms 655.606 MB
triton 0.198 ms 82.084 MB

# 2 GPUs, vocab 100K
fused 9.569 ms 6553.846 MB
triton 0.996 ms 819.364 MB
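The tensor-parallel case follows the same pattern. A hedged sketch, assuming the logits are split along the vocab dimension so each rank holds one shard; the `torch.distributed` all-reduces below illustrate the reductions involved and are not the PR's actual kernels or API (names like `tp_cross_entropy` and `vocab_start` are made up here).

```python
import torch
import torch.distributed as dist

def tp_cross_entropy(local_logits: torch.Tensor, targets: torch.Tensor,
                     vocab_start: int, group=None) -> torch.Tensor:
    # Forward pass only. local_logits: (num_tokens, local_vocab),
    # targets: (num_tokens,) holding *global* vocab indices.
    local_logits = local_logits.float()
    vocab_end = vocab_start + local_logits.shape[1]

    # Global max over the full vocab: local max, then a max all-reduce.
    max_ = local_logits.max(dim=1).values
    dist.all_reduce(max_, op=dist.ReduceOp.MAX, group=group)

    # Sum of exponentials over the full vocab: local partial sums, then a sum all-reduce.
    sum_exp = torch.exp(local_logits - max_.unsqueeze(1)).sum(dim=1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # Target logit: nonzero only on the rank that owns the target index.
    in_shard = (targets >= vocab_start) & (targets < vocab_end)
    target_logit = torch.zeros_like(max_)
    target_logit[in_shard] = local_logits[in_shard, targets[in_shard] - vocab_start]
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)

    # loss_i = logsumexp(logits_i) - logits_i[target_i], identical on every rank.
    return (max_ + sum_exp.log() - target_logit).mean()
```

This avoids gathering the full logits on any rank, which is consistent with the memory drop seen in the 2-GPU benchmark above.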

@jlamypoirier changed the title from "Fix triton cross-entropy for large vocab sizes" to "Fix triton cross-entropy for large vocab sizes, support tensor-parallel" on Feb 2, 2026