
Conversation

wwwjn (Contributor) commented Dec 5, 2025

We remove the lru_cache for attention masks because, inside the get_attention_mask() function, and_masks(*mask_mods) returns an object with a different object id on every call. create_attention_mask uses all of its parameters as the cache key, so the new object id always causes a cache miss.

Before the change (llama3 debugmodel_flex_attn):
[screenshot: Screenshot 2025-12-09 at 1 27 45 PM]

After the change:
[screenshot: Screenshot 2025-12-09 at 1 29 56 PM]
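
For reference, a minimal standalone sketch of the failure mode (not the actual torchtitan code; and_masks below is a stand-in for torch.nn.attention.flex_attention.and_masks): the combined mask_mod is a brand-new callable on every call, and lru_cache keys callables by identity, so the cache never hits.

import functools

def causal(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx

def and_masks(*mods):
    # Stand-in for flex_attention.and_masks: returns a *new* closure on every call.
    def combined(b, h, q_idx, kv_idx):
        return all(m(b, h, q_idx, kv_idx) for m in mods)
    return combined

@functools.lru_cache(maxsize=4)
def create_attention_mask(mask_mod, batch_size, seq_len):
    # Placeholder for the real mask construction (e.g. create_block_mask).
    return ("block_mask", mask_mod, batch_size, seq_len)

create_attention_mask(and_masks(causal), 1, 128)
create_attention_mask(and_masks(causal), 1, 128)
print(create_attention_mask.cache_info())  # hits=0, misses=2 -- every call is a miss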

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 5, 2025
return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)

# Use functools.partial to bind window_size while keeping the function cacheable
sliding_window_mod = functools.partial(_sliding_window_mask, window_size)
Contributor commented:

I might have missed this point -- calling partial a second time with the same window_size may not give a cache hit, and if so this approach won't work. Could you verify?

cc @fegin
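
A quick standalone check of this point (a sketch, not the PR code): two partial objects bound to the same window_size are distinct objects and compare unequal, so a cache keyed on them cannot hit.

import functools

def _sliding_window_mask(window_size, b, h, q_idx, kv_idx):
    return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)

a = functools.partial(_sliding_window_mask, 128)
b = functools.partial(_sliding_window_mask, 128)
print(a is b, a == b)  # False False -- a new object every time, so lru_cache keys never match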

wwwjn (Contributor Author) commented:

You are right, I missed it. I checked the object id when calling partial a second time with the same window_size, and the ids are different.

So I changed to using functools.lru_cache to do explicit caching, which uses window_size as the cache key. I verified that the ids are the same. Wdyt about this solution?
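
Roughly, a sketch of that factory approach (hypothetical names, not the exact PR code): the lru_cache sits on a function keyed only by window_size, so repeated calls hand back the same mask_mod object.

import functools

@functools.lru_cache(maxsize=4)
def get_sliding_window_mask_mod(window_size: int):
    # Cached on window_size; the returned closure is reused across calls.
    def sliding_window_mask(b, h, q_idx, kv_idx):
        return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)
    return sliding_window_mask

m1 = get_sliding_window_mask_mod(128)
m2 = get_sliding_window_mask_mod(128)
print(m1 is m2)  # True -- same object id, so downstream caches keyed on it can hit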

@wwwjn wwwjn changed the title Move sliding window mod to make the mod cacheable Cache sliding window mod for sliding window attention masks Dec 5, 2025
return blocked_mask_mod


@functools.lru_cache(4)
Contributor commented:

I'm worried about the only use case,
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/gpt_oss/model/model.py#L364
where and_masks is called on every iteration, and its result has a different object id on each call (at the beginning of each layer).

In general, the caching mechanism around masks doesn't sound very robust. The per-iteration overhead might be fine if caching is removed. Should we remove all caching altogether? cc @fegin

wwwjn (Contributor Author) commented:

I see, it looks like we are using and_masks() for almost all models now.

If and_masks returns a new object id for each call, then the lru_cache around create_attention_mask will always miss (because the object returned from and_masks is part of the cache key).

Contributor commented:

That's my guess. Could you verify? If so, I'd recommend we remove all the lru_cache annotations altogether, for better readability and some memory savings (because we would no longer maintain a cache that is never hit).

wwwjn (Contributor Author) commented:

Yes, I already verified that and_masks returns a different object id every time. I will move toward removing all cache annotations.

@wwwjn wwwjn changed the title Cache sliding window mod for sliding window attention masks Remove caching for attention masks Dec 9, 2025
tianyu-l (Contributor) left a comment:

LGTM, one nit

fegin (Contributor) commented Dec 9, 2025:

Yes, we should remove the lru_cache. It was added before we used and_masks, in the hope that we could cache some masks (e.g., the causal mask). That is very hard to achieve at the API level. It may be better to let users decide what to cache.
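
As an illustration of that last point (a hypothetical sketch, not a torchtitan API): the caller can cache explicitly, keyed on parameters it controls, instead of relying on object identity inside the helper.

# Hypothetical user-side cache: the key is whatever hashable tuple the caller
# chooses (e.g. mask type, seq_len, window_size), built once and reused.
_mask_cache = {}

def get_or_build_mask(key, build_fn):
    if key not in _mask_cache:
        _mask_cache[key] = build_fn()
    return _mask_cache[key]

# Illustrative usage inside a training loop:
# mask = get_or_build_mask(("sliding_window", seq_len, 128),
#                          lambda: create_attention_mask(mask_mod, batch_size, seq_len))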
