From c695bd7ac2fc4f1bf5b167af2df3d8104b9c5852 Mon Sep 17 00:00:00 2001 From: Zeev Melumian Date: Thu, 6 Feb 2025 16:04:31 +0200 Subject: [PATCH] Fix dk/dv grads not being extracted in flash attention when autograd is activated only on dv The flash attention backward function on TPU will only return key and value gradients if the key gradient is requested by torch.autograd --- torch_xla/experimental/custom_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 01ac4b142b43..45c7a4a6de0c 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -449,7 +449,7 @@ def fa_custom_backward( if require_grad_ab: grad_ab = grads[1] - if require_grad_k or require_grad_k: + if require_grad_k or require_grad_v: payload, _ = trace_pallas( _flash_attention_bwd_dkv, q,