Support fdma for models with attention bias #41882

Open
LoserCheems wants to merge 11 commits into huggingface:main from LoserCheems:support-fdma
Conversation

@LoserCheems

Contributor

What does this PR do?

Fixes #41465

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@MekkCyber @drbh

Introduces a cached utility to detect flash_dmattn support and minimum version on GPU setups.
Validates package presence and requires >= 1.0.0 under CUDA; reports unavailable for HIP/MLU.
Enables safer conditional feature use and avoids runtime errors on unsupported platforms.
Exposes the new availability helper to match other flash attention checks.
Enables downstream feature gating and conditional logic based on backend presence.
Adds a decorator to conditionally skip tests that depend on Flash Dynamic Mask Attention when the feature or kernels backend is unavailable.

Verifies the kernels registry for the dynamic mask attention kernel to prevent spurious failures and aligns gating with existing Flash Attention tests.
Introduces a forward wrapper for the flash dynamic mask attention path to enable the optimized backend with correct input layout and dtype handling.

Ensures stable dtype under autocast, quantization, and PEFT LayerNorm casting, validates non-empty inputs, respects an is_causal override, and warns when attention outputs are requested but unsupported.
Introduces a lazy-loaded wrapper for flash dynamic mask attention to enable optional use and compatibility across versions and torch.compile modes.

Processes and filters kwargs against kernel capabilities, exposing softmax scale, softcap, deterministic (env-controlled), and attention sink auxiliary when supported.
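
One way to filter kwargs against what a kernel accepts is to inspect its signature. The sketch below is an assumption about the general technique, not the PR's implementation: `filter_supported_kwargs`, `old_kernel`, and `new_kernel` are hypothetical, and dropping `None` values is a choice made here for illustration.

```python
import inspect


def filter_supported_kwargs(kernel_fn, **kwargs):
    """Keep only keyword arguments that `kernel_fn` accepts, dropping None values."""
    params = inspect.signature(kernel_fn).parameters
    accepts_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    return {
        name: value
        for name, value in kwargs.items()
        if value is not None and (accepts_var_kwargs or name in params)
    }


# Hypothetical stand-ins for two kernel versions with different capabilities.
def old_kernel(q, k, v, softmax_scale=None):
    return "old"


def new_kernel(q, k, v, softmax_scale=None, softcap=None, deterministic=False):
    return "new"


requested = {"softmax_scale": 0.125, "softcap": 30.0, "deterministic": True, "s_aux": None}
old_kwargs = filter_supported_kwargs(old_kernel, **requested)  # only softmax_scale survives
new_kwargs = filter_supported_kwargs(new_kernel, **requested)  # s_aux is dropped because it is None
```

With this approach the caller can always pass the full set of options and let the installed kernel version decide which ones apply.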

Mitigates PEFT-induced fp32 casts by restoring target dtypes for queries/keys/values and bias to ensure kernel compatibility.

Handles causal vs full attention, applies attention mask and bias, and normalizes outputs when kernels return tuples.
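
For readers unfamiliar with additive attention bias, the following pure-Python reference shows the computation a fused kernel replaces: scaled scores, an optional bias added to the scores, an optional causal mask, and a softmax over keys. It is a toy single-head sketch on nested lists, assuming equal query and key lengths for the causal case; `reference_attention` and `first_output` are hypothetical names and say nothing about how the actual kernel is written.

```python
import math


def reference_attention(q, k, v, bias=None, is_causal=False):
    """Single-head attention on nested lists: softmax(q @ k^T * scale + bias) @ v."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_row in enumerate(q):
        scores = []
        for j, k_row in enumerate(k):
            if is_causal and j > i:
                scores.append(float("-inf"))  # future positions are masked out
                continue
            s = sum(a * b for a, b in zip(q_row, k_row)) * scale
            if bias is not None:
                s += bias[i][j]  # additive attention bias
            scores.append(s)
        peak = max(scores)  # subtract the row maximum for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w / total * v_row[d] for w, v_row in zip(weights, v)) for d in range(len(v[0]))])
    return out


def first_output(result):
    """Some kernels return (output, extras...); keep only the attention output."""
    return result[0] if isinstance(result, tuple) else result
```

A bias entry of negative infinity behaves like a mask, which is why one tensor can carry both masking and learned score offsets.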

Improves performance, stability, and configurability while keeping the API consistent.
Enables a new high-performance attention backend with dynamic masking and wires it into dispatch and the attention interface.

Adds a capability flag and a comprehensive availability check covering package/version, device, and dtype, with NPU handling and clearer user guidance. Improves early, actionable errors and warnings.

Updates help text to advertise the new attn_implementation option.
Documents availability of the flash_dmattn attention implementation alongside existing options.
Improves clarity and discoverability by linking to the project.
Removes an unused import to reduce noise and satisfy linters.
Simplifies the causal condition for readability without changing behavior.
Adds a trailing comma in parameters to align with formatting conventions and cleaner diffs.
Copilot AI review requested due to automatic review settings October 27, 2025 06:54

Copilot AI left a comment

Contributor

Pull Request Overview

This PR adds support for Flash Dynamic Mask Attention (FDMA) to models that use attention bias. FDMA is a more memory-efficient attention mechanism that can handle dynamic masks and attention biases.

Key changes:

  • Added infrastructure for detecting and enabling Flash Dynamic Mask Attention
  • Implemented FDMA forward pass with support for attention masks and biases
  • Added validation and compatibility checks for FDMA usage

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

Summary per file:

| File | Description |
| --- | --- |
| src/transformers/utils/import_utils.py | Added function to check FDMA package availability |
| src/transformers/utils/__init__.py | Exported FDMA availability check function |
| src/transformers/testing_utils.py | Added test decorator for FDMA requirements |
| src/transformers/modeling_utils.py | Added FDMA support flags and validation logic to model class |
| src/transformers/modeling_flash_dynamic_mask_attention_utils.py | Implemented core FDMA utilities and forward pass |
| src/transformers/integrations/flash_dynamic_mask_attention.py | Added integration layer for FDMA with model interface |
| docs/source/en/attention_interface.md | Updated documentation to mention FDMA support |

Comment on lines +883 to +884
    elif is_torch_mlu_available():
        return False

Copilot AI Oct 27, 2025

The elif and else branches return the same value (False). These can be simplified to a single else clause.

Suggested change
    elif is_torch_mlu_available():
        return False

elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
    logger.warning_once(
        "Flash Dynamic Mask Attention only supports torch.float16 and torch.bfloat16 dtypes, but"
        f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"

Copilot AI Oct 27, 2025

Corrected spelling of 'dype' to 'dtype'.

Suggested change
f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
f" the current dtype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"

Comment on lines +204 to +205
attention_bias (`torch.Tensor`, *optional*):
The attention bias tensor of size `(batch_size, num_heads, query_len, key_len)` to add to attention scores.

Copilot AI Oct 27, 2025

The attention_bias parameter is documented in the docstring but is not listed in the function signature's Args section. This parameter should be included with its description in the Args list for consistency with other parameters like attention_mask.

@vasqu vasqu left a comment

Collaborator

There is a lot of duplicated effort as it's mostly the same as our fa implementation.

The idea is to wrap this within the kernels package instead (cc @MekkCyber @drbh) and let us handle all the forwards, wrapping, etc., as we already support other FA implementations that way. The only modification we would then need is to make sure that the parameters match, i.e. s_aux <-> attention_bias.

Edit: meaning kernels as in https://github.com/huggingface/kernel-builder

return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)


def require_flash_dmattn(test_case):

Collaborator

The idea is to not use a separate package but to integrate it into kernels, similar to https://huggingface.co/kernels-community/flash-attn3

Collaborator

This logic could all be avoided if it was integrated into kernels.

Collaborator

Same here, we would handle most of this. Maybe small modifications would be needed to support the attention_bias parameter in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py

Collaborator

Same, we will handle this then.

@LoserCheems

Contributor Author

Hi @vasqu, is there any reference tutorial for this process available?

@vasqu

vasqu commented Oct 27, 2025

Collaborator

There is a blog at https://huggingface.co/blog/kernel-builder but I would wait for @MekkCyber @drbh as they are more familiar there and can provide proper guidance.

@MekkCyber MekkCyber left a comment

Contributor

Thanks for all this work @LoserCheems! Though, as @vasqu suggested, it's much better to use fdma via kernel-builder. You can just follow the blog post, build the kernel, upload the builds to the Hub, and then start using the exported torch.ops from there.

@LoserCheems

Contributor Author

Thank you @MekkCyber for your guidance, but I have encountered some issues with building the wheel: huggingface/kernel-builder#283

@jordane95

Contributor

Hi @LoserCheems , any progress here?

@LoserCheems

Contributor Author

> Hi @LoserCheems , any progress here?

I'm very sorry. I haven't been able to proceed smoothly with the kernel-builder process. I think I need help...

Development

Successfully merging this pull request may close these issues.

Support s_aux in GPT-OSS

5 participants