Add zero dim tensor check when using flash_attention#38280
Conversation
6afac40 to
a9fdcf4
Compare
|
cc @ArthurZucker for FA2 |
ArthurZucker
left a comment
There was a problem hiding this comment.
Thanks, I'd rather we write 3 explicit checks as only query, key and value need this check. Maybe even just testing key?
Do you mean we don't need write a function to do this check, just call explicit check directly? if so, I agree with you, I can write 3 explicit checks for query, key and value, or only for key, and give users info which shape is wrong. |
|
Yeah, I mean only checking q should be enough as well no? |
3d36772 to
8365960
Compare
Signed-off-by: ranzhejiang <zhejiang.ran@intel.com>
Yes, I agree with you and have change my code, thanks for review |
Signed-off-by: ranzhejiang <zhejiang.ran@intel.com>
|
@ArthurZucker Hi ArthurZucker, I have changed my code following your advice, can you help review this PR ? Thanks |
|
sorry I was out for holidays |
The cuda or triton kernel can not support this case: dimensions of size is 0, but traditional SDPA is implemented using PyTorch's tensor operations, which have robust support for tensors with dimensions of size 0. We need to add this check and error tips for developers when using flash_attention. Related issue is in deepspeedai/DeepSpeed#7275