Reduce allocation overhead in quantized sdpa #15610
meta-codesync[bot] merged 22 commits into gh/kimishpatel/202/base
Conversation
For small models, dequantizing portions of the v cache causes extra allocation overhead. A better way to handle this would probably be to dequantize the entire v cache outside the model. There isn't a significant perf advantage from this yet, but subsequent diffs will use a caching allocator, where this refactor helps. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15610
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit 4e529d1 with merge base 8af8252.
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull Request Overview
This PR refactors the quantized scaled dot-product attention (SDPA) implementation to reduce allocation overhead by moving the dequantization buffer allocation from inside the dequant_and_gemm function to the outer cpu_flash_attention scope. Instead of allocating a new std::vector for each dequantization operation, a pre-allocated per-thread scratch buffer is now shared across iterations.
Key changes:
- Added a `buf_qdq_ptr` parameter to the `dequant_and_gemm` and `_qk_at_v_gemm` functions to accept externally allocated dequantization buffers
- Allocated a shared scratch buffer (`scratch_for_quant_dequant`) in `cpu_flash_attention` with per-thread partitioning
- Removed the local `std::vector<float> dequantized_v_data` allocation from `dequant_and_gemm`
```cpp
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
// query.options());
int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
// Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
```
The comment says "align to 64 bytes" but `kAlignment = 32` aligns to 32 elements. Since `size_per_thread_qdq_vec` is an element count (not a byte count), and assuming `accum_t` is `float` (4 bytes), this aligns to 128 bytes (32 * 4), not 64 bytes.
Either:
- Change `kAlignment` to 16 if 64-byte alignment is desired, or
- Update the comment to say "align to 32 elements" or "align to 128 bytes (for float)"
```diff
- // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
+ // Lets align size_per_thread_qdq_vec to 32 elements (128 bytes for float), for coalesced cache reads,
```
```cpp
// by padding with right number of per thread elements
constexpr int64_t kAlignment = 32;
size_per_thread_qdq_vec =
    (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
```
The alignment calculation is incorrect. The formula `(x + kAlignment - 1) & (-(kAlignment - 1))` uses the wrong mask.
For aligning to a power-of-2 boundary, the correct formula is:
`(size_per_thread_qdq_vec + kAlignment - 1) & (-kAlignment)`
or equivalently:
`(size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1)`
The current code uses `-(kAlignment - 1)`, which equals -31 = 0xFFFFFFE1, but the correct mask should be -32 = 0xFFFFFFE0 to properly zero out the bottom 5 bits.
```diff
- (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
+ (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment;
```
@copilot are you sure? Please double check again
```cpp
// at::Tensor buf_reduced = at::empty(
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
// query.options());
int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
```
The buffer size calculation appears to be larger than necessary. The dequantize operation needs kvBlockSize * headSize elements (at most kvSplitSize * headSize), but this allocates qSplitSize * kvSplitSize * headSize. The extra qSplitSize factor seems unnecessary and wastes memory per thread.
Consider changing to:
```diff
- int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize;
+ int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
```
@mergennachin I've opened a new pull request, #15852, to work on those changes. Once the pull request is ready, I'll request review from you.
37078bb into gh/kimishpatel/202/base
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #15610 by @kimishpatel ^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/kimishpatel/202/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/202/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/kimishpatel/202/orig
Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/)
@diff-train-skip-merge
Co-authored-by: Kimish Patel <kimishpatel@fb.com>
Stack from ghstack (oldest at bottom):
For small models, dequantizing portions of the v cache causes extra allocation overhead.
Probably a better way to handle this is to dequantize the entire v cache outside the model.
There isn't a significant perf advantage from this yet, but subsequent diffs will use a caching allocator, where this refactor helps.
Differential Revision: D85532077