Remove caching for attention masks #2117
Conversation
torchtitan/models/attention.py
Outdated
```python
    return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)


# Use functools.partial to bind window_size while keeping the function cacheable
sliding_window_mod = functools.partial(_sliding_window_mask, window_size)
```
I might have missed this point -- calling partial a second time with the same window_size may not give a cache hit, and if so this approach won't work. Could you verify?
cc @fegin
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right, I missed it. I checked the object id when calling the partial a second time with the same window_size, and the ids are different.
So I changed to using functools.lru_cache to do explicit caching, which uses window_size as the cache key. I verified the ids are the same. Wdyt about this solution?
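For reference, a minimal stdlib-only sketch of the identity issue discussed here. The predicate matches the diff above, while the factory name is made up for illustration: `functools.partial` builds a new object on every call, whereas an `lru_cache`-wrapped factory keyed on `window_size` returns the same object for repeated calls.

```python
import functools


def _sliding_window_mask(window_size, b, h, q_idx, kv_idx):
    # Same predicate as in the diff above.
    return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size)


# partial: a fresh object every time, even with the same window_size.
first = functools.partial(_sliding_window_mask, 128)
second = functools.partial(_sliding_window_mask, 128)
assert first is not second  # different object ids, so no cache hit downstream


# lru_cache keyed on window_size: the same object comes back.
@functools.lru_cache(maxsize=4)
def make_sliding_window_mod(window_size):  # illustrative name, not torchtitan's
    return functools.partial(_sliding_window_mask, window_size)


assert make_sliding_window_mod(128) is make_sliding_window_mod(128)
```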
torchtitan/models/attention.py
Outdated
```python
    return blocked_mask_mod


@functools.lru_cache(4)
```
I'm worried that for the only use case
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/gpt_oss/model/model.py#L364
and_masks will be called for each iteration, and its result will have a different object id on each call (at the beginning of each layer).
In general, the caching mechanism around masks doesn't sound very robust. The per-iteration overhead might be fine if caching is removed. Should we remove all caching altogether? cc @fegin
I see, looks like we are using and_masks() for almost all models now.
If `and_masks` returns a new object id for each call, then the lru_cache around `create_attention_mask` will always miss (because the id returned from `and_masks` will be part of the cache key).
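A minimal sketch of why that cache can never hit, using simplified stand-ins for `and_masks` and `create_attention_mask` (the real signatures in PyTorch/torchtitan differ): because the freshly built mask_mod is part of the cache key, every call is a miss.

```python
import functools


def causal(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx


def and_masks_standin(*mask_mods):
    # Mimics flex attention's and_masks: returns a *new* callable on every call.
    def combined(b, h, q_idx, kv_idx):
        result = True
        for mod in mask_mods:
            result = result & mod(b, h, q_idx, kv_idx)
        return result

    return combined


@functools.lru_cache(4)
def create_attention_mask(mask_mod, batch, seq_len):
    # Stand-in: the real function builds a BlockMask for flex attention.
    return ("block_mask", mask_mod, batch, seq_len)


for _ in range(3):
    # A new combined mask_mod each iteration means a new cache key each time.
    create_attention_mask(and_masks_standin(causal), 8, 1024)

print(create_attention_mask.cache_info())  # hits stay at 0; misses grow
```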
That's my guess. Could you verify? If so, I'd recommend we remove all the lru_cache annotations altogether for better readability and some memory savings (because we no longer maintain a cache that never gets any hits).
Yes, I already verified that and_masks returns a different object id every time. I will move towards removing all the cache annotations.
tianyu-l left a comment:
LGTM, one nit
Yes, we should remove the lru_cache. It was added before we used and_masks, in the hope that we could cache some masks (e.g., the causal mask). This is very hard to achieve at the API level. It may be better to let users decide what to cache.
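A hedged sketch of what "let users decide what to cache" could look like at the call site; everything below (the `build_mask` stub and the cache shape) is illustrative rather than torchtitan code. The caller rebuilds the mask only when the shape it depends on changes:

```python
def build_mask(mask_mod, batch, seq_len):
    # Illustrative stand-in for create_attention_mask, which really returns a BlockMask.
    return ("block_mask", mask_mod, batch, seq_len)


def causal(b, h, q_idx, kv_idx):
    return kv_idx <= q_idx


class AttentionWithCallerCache:
    """Caller-owned caching: rebuild the mask only when its key changes."""

    def __init__(self):
        self._key = None
        self._mask = None

    def get_mask(self, batch, seq_len):
        key = (batch, seq_len)
        if key != self._key:
            self._mask = build_mask(causal, batch, seq_len)
            self._key = key
        return self._mask


attn = AttentionWithCallerCache()
assert attn.get_mask(8, 1024) is attn.get_mask(8, 1024)  # reused for the same shape
```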
We remove the lru_cache for attention masks because, in the `get_attention_mask()` function, `and_masks(*mask_mods)` returns a different object id on every call. `create_attention_mask` uses all of its parameters as the cache key, so the new object id always causes a cache miss.

Before the change (llama3 debugmodel_flex_attn):

After the change:
