diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 01ac4b142b43..45c7a4a6de0c 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -449,7 +449,7 @@ def fa_custom_backward( if require_grad_ab: grad_ab = grads[1] - if require_grad_k or require_grad_k: + if require_grad_k or require_grad_v: payload, _ = trace_pallas( _flash_attention_bwd_dkv, q,