[CK_TILE][FMHA] Enable gpt-oss sink#3490
Conversation
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
There was a problem hiding this comment.
Pull request overview
This PR enables support for "gptoss sink" tokens in the FMHA (Fused Multi-Head Attention) implementation by adding a sink_ptr parameter throughout the codebase. The sink feature allows attention mechanisms to maintain a virtual sink token that affects softmax normalization.
Key Changes:
- Added
sink_ptrparameter to kernel argument structures and pipeline operators - Modified initialization logic to conditionally set sink values based on infinity checks
- Updated validation logic to account for sink tokens in softmax computation
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp |
Added sink_v parameter and conditional initialization logic for m/l tiles based on sink value |
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp |
Added sink_v parameter to both operator() overloads with conditional tile initialization |
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp |
Added sink_v parameter to operator() methods with conditional tile initialization |
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp |
Added sink_v parameter with split-aware conditional initialization; includes commented debug prints |
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp |
Added sink_v parameter with split-aware conditional initialization |
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp |
Added sink_v parameter to both operator() overloads with conditional tile initialization |
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp |
Added sink_ptr to Kargs, MakeKargs methods; compute per-head sink value from pointer |
include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp |
Added sink_ptr to Kargs, MakeKargs methods; compute per-head sink value from pointer |
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp |
Added sink_ptr to Kargs and all MakeKargs overloads; compute per-head sink value; removed extraneous whitespace |
include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp |
Added sink_ptr to Kargs and MakeKargsImpl methods; compute per-head sink value from pointer |
example/ck_tile/01_fmha/fmha_fwd_runner.hpp |
Added init_sink_value parameter, sink tensor allocation/initialization, and validation logic for sink tokens |
example/ck_tile/01_fmha/fmha_fwd.hpp |
Added sink_ptr field to all args structures and threaded through kargs creation |
example/ck_tile/01_fmha/example_fmha_fwd.cpp |
Added init_sink command-line argument |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…ine_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…ine_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 17 out of 17 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
Do we see any performance regression after adding the extra sink_ptr kernel argument? |
…_async_trload.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…ine_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
no see any performance regression in my local test. |
* Enable gptoss sink Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * add gptoss sink test Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update CHANGELOG.md Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix test args error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update test_fmha_fwd.cpp * update sink test Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Revert "update sink test" This reverts commit 970b4f1. * update sink test Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update valid sink_v in splitkv pipeline Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update example_fmha_fwd.cpp * fix lse error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix clangformat error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix aiter scale error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update block_fmha_pipeline_qr_ks_vs.hpp * div scale_s for sink_value Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update fmha_fwd_runner.hpp * update sink_value with bias Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Fix typo in dropout parameter in fmha_batch_prefill_kernel * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update example_fmha_fwd.cpp * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * optimized some code Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix splitkv error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update sink reference Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update fmha_fwd_runner.hpp * Update smoke_test_fwd_sink.sh --------- Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Proposed changes
This PR enables support for "gptoss sink" tokens in the FMHA (Fused Multi-Head Attention) implementation by adding a sink_ptr parameter throughout the codebase. The sink feature allows attention mechanisms to maintain a virtual sink token that affects softmax normalization.
Key Changes:
Added sink_ptr parameter to kernel argument structures and pipeline operators
Modified initialization logic to conditionally set sink values based on infinity checks
Updated validation logic to account for sink tokens in softmax computation
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
None