Support fdma for models with attention bias #41882
Introduces a cached utility to detect flash_dmattn support and minimum version on GPU setups. Validates package presence and requires >= 1.0.0 under CUDA; reports unavailable for HIP/MLU. Enables safer conditional feature use and avoids runtime errors on unsupported platforms.
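A minimal sketch of what such a cached package and version check could look like, using only the standard library. The helper names `version_tuple` and `package_meets_minimum` are hypothetical and are not the functions added in this PR; the CUDA/HIP/MLU platform check is omitted, and it is an assumption that the installed distribution name matches the import name `flash_dmattn`.

```python
import importlib.metadata
import importlib.util
import re
from functools import lru_cache


def version_tuple(version: str) -> tuple:
    """Parse '1.0.0.post1' into (1, 0, 0); parsing stops at the first non-numeric piece.

    Pre-release suffixes such as 'rc1' are ignored, which is a simplification.
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '1.0' compares equal to '1.0.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


@lru_cache(maxsize=None)
def package_meets_minimum(package: str, minimum: str) -> bool:
    """Return True only if `package` is importable and its version is >= `minimum`."""
    if importlib.util.find_spec(package) is None:
        return False
    try:
        installed = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return False
    return version_tuple(installed) >= version_tuple(minimum)
```

With this sketch, `package_meets_minimum("flash_dmattn", "1.0.0")` would be evaluated once per process and served from the cache afterwards.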
Exposes the new availability helper to match other flash attention checks. Enables downstream feature gating and conditional logic based on backend presence.
Adds a decorator to conditionally skip tests that depend on Flash Dynamic Mask Attention when the feature or kernels backend is unavailable. Verifies the kernels registry for the dynamic mask attention kernel to prevent spurious failures and aligns gating with existing Flash Attention tests.
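The skip pattern described above can be sketched with `unittest.skipUnless`. The factory `require_backend` and the two example functions are hypothetical; the PR's actual decorator also consults the kernels registry, which is not reproduced here.

```python
import unittest


def require_backend(is_available, reason):
    """Build a decorator that skips a test unless `is_available()` returns True."""

    def decorator(test_case):
        # skipUnless returns the test unchanged when the condition holds,
        # and a wrapper that raises unittest.SkipTest otherwise.
        return unittest.skipUnless(is_available(), reason)(test_case)

    return decorator


@require_backend(lambda: False, "test requires Flash Dynamic Mask Attention")
def needs_backend_case():
    return "ran"


@require_backend(lambda: True, "test requires Flash Dynamic Mask Attention")
def backend_present_case():
    return "ran"
```

Passing the availability check as a callable keeps the decorator reusable for any optional backend.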
Introduces a forward wrapper for the flash dynamic mask attention path to enable the optimized backend with correct input layout and dtype handling. Ensures stable dtype under autocast, quantization, and PEFT LayerNorm casting, validates non-empty inputs, respects an is_causal override, and warns when attention outputs are requested but unsupported.
Introduces a lazy-loaded wrapper for flash dynamic mask attention to enable optional use and compatibility across versions and torch.compile modes. Processes and filters kwargs against kernel capabilities, exposing softmax scale, softcap, deterministic (env-controlled), and attention sink auxiliary when supported. Mitigates PEFT-induced fp32 casts by restoring target dtypes for queries/keys/values and bias to ensure kernel compatibility. Handles causal vs full attention, applies attention mask and bias, and normalizes outputs when kernels return tuples. Improves performance, stability, and configurability while keeping the API consistent.
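As an illustration of filtering kwargs against what a kernel accepts, the sketch below inspects the callable's signature. `filter_supported_kwargs` and `fake_kernel` are hypothetical stand-ins, not code from this PR, and the real wrapper may detect capabilities differently.

```python
import inspect


def filter_supported_kwargs(kernel_fn, **kwargs):
    """Keep only the non-None kwargs that `kernel_fn` declares in its signature."""
    params = inspect.signature(kernel_fn).parameters
    accepts_any = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    return {
        name: value
        for name, value in kwargs.items()
        if value is not None and (accepts_any or name in params)
    }


def fake_kernel(q, k, v, softmax_scale=None, is_causal=False):
    """Stand-in for an attention kernel that does not know about `softcap`."""
    return ("ok", softmax_scale, is_causal)


# `softcap` is dropped because fake_kernel does not accept it;
# `deterministic` is dropped because it is None.
supported = filter_supported_kwargs(
    fake_kernel, softmax_scale=0.5, softcap=30.0, deterministic=None
)
result = fake_kernel("q", "k", "v", **supported)
```

Filtering by signature lets one wrapper serve several kernel versions without raising `TypeError` on unknown keyword arguments.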
Enables a new high-performance attention backend with dynamic masking and wires it into dispatch and the attention interface. Adds a capability flag and a comprehensive availability check covering package/version, device, and dtype, with NPU handling and clearer user guidance. Improves early, actionable errors and warnings. Updates help text to advertise the new attn_implementation option.
Documents availability of the flash_dmattn attention implementation alongside existing options. Improves clarity and discoverability by linking to the project.
Removes an unused import to reduce noise and satisfy linters. Simplifies the causal condition for readability without changing behavior. Adds a trailing comma in parameters to align with formatting conventions and cleaner diffs.
Pull Request Overview
This PR adds support for Flash Dynamic Mask Attention (FDMA) to models that use attention bias. FDMA is a more memory-efficient attention mechanism that can handle dynamic masks and attention biases.
Key changes:
- Added infrastructure for detecting and enabling Flash Dynamic Mask Attention
- Implemented FDMA forward pass with support for attention masks and biases
- Added validation and compatibility checks for FDMA usage
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/transformers/utils/import_utils.py | Added function to check FDMA package availability |
| src/transformers/utils/__init__.py | Exported FDMA availability check function |
| src/transformers/testing_utils.py | Added test decorator for FDMA requirements |
| src/transformers/modeling_utils.py | Added FDMA support flags and validation logic to model class |
| src/transformers/modeling_flash_dynamic_mask_attention_utils.py | Implemented core FDMA utilities and forward pass |
| src/transformers/integrations/flash_dynamic_mask_attention.py | Added integration layer for FDMA with model interface |
| docs/source/en/attention_interface.md | Updated documentation to mention FDMA support |
    elif is_torch_mlu_available():
        return False
The elif and else branches return the same value (False). These can be simplified to a single else clause.
    elif is_torch_mlu_available():
        return False
    elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
        logger.warning_once(
            "Flash Dynamic Mask Attention only supports torch.float16 and torch.bfloat16 dtypes, but"
            f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
Corrected spelling of 'dype' to 'dtype'.
    - f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
    + f" the current dtype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
    attention_bias (`torch.Tensor`, *optional*):
        The attention bias tensor of size `(batch_size, num_heads, query_len, key_len)` to add to attention scores.
The attention_bias parameter is documented in the docstring but is not listed in the function signature's Args section. This parameter should be included with its description in the Args list for consistency with other parameters like attention_mask.
There is a lot of duplicated effort as it's mostly the same as our fa implementation.
The idea is to wrap this within the kernels package instead (cc @MekkCyber @drbh) and let us handle all the forwards, wrapping etc as we already support other FA implementations that way. The only modifications that we would need then, would be to make sure that the parameters match, i.e. s_aux <-> attention_bias
Edit: meaning kernels as in https://github.com/huggingface/kernel-builder
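To picture the parameter matching mentioned above, here is a hedged sketch of a keyword-renaming adapter. The helper `rename_kwargs` is hypothetical, and the mapping direction shown is only an assumption for illustration; whether `attention_bias` and `s_aux` are actually interchangeable is for the kernels maintainers to decide.

```python
def rename_kwargs(kwargs: dict, mapping: dict) -> dict:
    """Return a copy of `kwargs` with keys renamed according to `mapping`."""
    renamed = dict(kwargs)  # copy so the caller's dict is left untouched
    for old_name, new_name in mapping.items():
        if old_name not in renamed:
            continue
        if new_name in renamed:
            raise TypeError(f"got both '{old_name}' and '{new_name}'")
        renamed[new_name] = renamed.pop(old_name)
    return renamed


# Assumed direction: a model passes `attention_bias`, the kernel expects `s_aux`.
model_kwargs = {"attention_bias": "bias-tensor", "dropout": 0.0}
kernel_kwargs = rename_kwargs(model_kwargs, {"attention_bias": "s_aux"})
```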
        return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)


    def require_flash_dmattn(test_case):
The idea is to not use a separate package but to integrate it into kernels, similar to https://huggingface.co/kernels-community/flash-attn3
This logic could all be avoided if it was integrated into kernels.
Same here, we would handle most of this. Maybe small modifications would be needed to support the attention_bias parameter in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py
Same here; we would handle this then.
Hi @vasqu, is there a tutorial for this process that I can refer to?
There is a blog post at https://huggingface.co/blog/kernel-builder, but I would wait for @MekkCyber @drbh, as they are more familiar with it and can provide proper guidance.
MekkCyber left a comment
Thanks for all this work, @LoserCheems! As @vasqu suggested, though, it's much better to use fdma via kernel-builder: you can follow the blog post, build the kernel, upload the builds to the Hub, and then start using the exported torch.ops from there.
Thank you @MekkCyber for your guidance, but I have encountered some issues building the wheel: huggingface/kernel-builder#283
Hi @LoserCheems, any progress here?
I'm very sorry. I haven't been able to proceed smoothly with the kernel-builder process. I think I need help...
What does this PR do?
Fixes #41465
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@MekkCyber @drbh