
Commit f0881b5

Use f32 scratch for output so we only need to transfer output with desired dtype back to HBM. (#8924)
1 parent 2d6f57a commit f0881b5

File tree: 2 files changed (+35, -25 lines)

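In short, the commit moves the output dtype cast from the Python wrapper into the kernel: attention is accumulated in a float32 VMEM scratch, and the result is written to the output ref already in the query dtype, so only the desired dtype is transferred back to HBM. Below is a minimal, self-contained sketch of that pattern in plain JAX (a toy single-head online-softmax attention with illustrative names, not the Pallas kernel itself):

import jax.numpy as jnp

def toy_blockwise_attention(q, k, v, blk=128):
  """Online-softmax attention over KV blocks with an f32 accumulator."""
  acc = jnp.zeros(q.shape, jnp.float32)              # f32 scratch, like acc_ref
  m = jnp.full(q.shape[:1], -jnp.inf, jnp.float32)   # running max, like m_ref
  l = jnp.zeros(q.shape[:1], jnp.float32)            # running sum, like l_ref
  for start in range(0, k.shape[0], blk):            # one KV block at a time
    s = q.astype(jnp.float32) @ k[start:start + blk].astype(jnp.float32).T
    m_new = jnp.maximum(m, s.max(axis=-1))
    p = jnp.exp(s - m_new[:, None])
    alpha = jnp.exp(m - m_new)
    l = l * alpha + p.sum(axis=-1)
    acc = acc * alpha[:, None] + p @ v[start:start + blk].astype(jnp.float32)
    m = m_new
  # Cast exactly once when producing the final output, so only the desired
  # dtype ever leaves the accumulator.
  return (acc / l[:, None]).astype(q.dtype)

With bf16 inputs this returns bf16 directly, rather than materializing an f32 result and casting it afterwards, which is the wrapper-level behavior the two diffs below remove.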

torch_xla/experimental/custom_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -997,9 +997,9 @@ def ragged_paged_attention(
           q.shape
       ],
       [  # output dtype
-          torch.float32,
+          q.dtype,
       ])
-  return output[0].to(q.dtype)
+  return output[0]


 def _multi_queries_paged_attention_nonkernel(

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 33 additions & 23 deletions
@@ -81,8 +81,8 @@ def ref_ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
 ):
-  check_inputs_shapes(queries, kv_pages, kv_lens, page_indices, cu_q_lens,
-                      num_seqs)
+  validate_static_inputs(queries, kv_pages, kv_lens, page_indices, cu_q_lens,
+                         num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   _, _, num_combined_kv_heads, head_dim = kv_pages.shape
@@ -124,7 +124,7 @@ def ref_ragged_paged_attention(


 # Expect to run these checkes during runtime.
-def validate_inputs_on_runtime(
+def validate_dynamic_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.
     Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
@@ -135,7 +135,8 @@ def validate_inputs_on_runtime(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
 ):
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens,
+                         num_seqs, sliding_window, soft_cap)
   max_num_batched_tokens = q.shape[0]
   page_size = kv_pages.shape[1]
   max_num_seqs, pages_per_seq = page_indices.shape
@@ -157,21 +158,19 @@ def validate_inputs_on_runtime(
     if q_len > kv_len:
       raise ValueError(
           f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.")
-  if sliding_window is not None and sliding_window <= 0:
-    raise ValueError(f"{sliding_window=} must be positive.")
-  if soft_cap is not None and soft_cap == 0.0:
-    raise ValueError(f"{soft_cap=} must not be 0.0.")


 # Expect to run these checks during compile time.
-def check_inputs_shapes(
+def validate_static_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.
     Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
@@ -198,6 +197,10 @@ def check_inputs_shapes(
         f" {cu_q_lens.dtype=}.")
   if num_q_heads % num_kv_heads != 0:
     raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
+  if sliding_window is not None and sliding_window <= 0:
+    raise ValueError(f"{sliding_window=} must be positive.")
+  if soft_cap is not None and soft_cap == 0.0:
+    raise ValueError(f"{soft_cap=} must not be 0.0.")


 def ragged_paged_attention_kernel(
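Besides the rename, the sliding_window and soft_cap checks move from the runtime validator into the compile-time one: both are plain Python scalars known at trace time, unlike kv_lens, cu_q_lens and page_indices, whose values only exist on device at runtime. A small standalone sketch of just the relocated checks (hypothetical helper name, not part of the file):

def _check_static_options(sliding_window: int | None, soft_cap: float | None):
  # Both arguments are Python scalars, so these checks can run while the
  # kernel is being traced/lowered instead of once per call at runtime.
  if sliding_window is not None and sliding_window <= 0:
    raise ValueError(f"{sliding_window=} must be positive.")
  if soft_cap is not None and soft_cap == 0.0:
    raise ValueError(f"{soft_cap=} must not be 0.0.")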
@@ -218,6 +221,7 @@ def ragged_paged_attention_kernel(
     sems,  # [2, 2]
     l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
     m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
+    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
     *,
     sm_scale: float,
     sliding_window: int | None = None,
@@ -341,7 +345,7 @@ def flash_attention(
       v,  # [num_kv_per_blk, head_dim]
       head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
       head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
-      head_o_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
+      head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
       *,
       kv_blk_idx,
   ):
@@ -362,7 +366,7 @@ def flash_attention(
         num_q_per_blk * num_q_heads_per_kv_head,
         128,
     )
-    assert head_o_ref.shape == (
+    assert head_acc_ref.shape == (
        num_q_per_blk,
        num_q_heads_per_kv_head,
        head_dim,
@@ -398,8 +402,8 @@ def init_scratch_ref():
           num_q_heads_per_kv_head,
       )
       masked_store(
-          head_o_ref,
-          jnp.zeros_like(head_o_ref),
+          head_acc_ref,
+          jnp.zeros_like(head_acc_ref),
           store_start,
           store_end,
       )
@@ -457,17 +461,17 @@ def broadcast_to_shape(arr, shape):
       return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])],
                              axis=1)

-    o_curr = head_o_ref[...].reshape(-1, head_dim)
+    o_curr = head_acc_ref[...].reshape(-1, head_dim)
     l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
     beta = broadcast_to_shape(beta, qkv.shape)
     l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
     out = lax.div(
         l_alpha * o_curr + beta * qkv,
         l_next_safe,
-    ).astype(head_o_ref.dtype)
+    )
     masked_store(
-        head_o_ref,
-        out.reshape(head_o_ref.shape),
+        head_acc_ref,
+        out.reshape(head_acc_ref.shape),
         store_start,
         store_end,
     )
@@ -513,7 +517,7 @@ def prefetch_next_kv_blk():
         v,
         l_ref.at[kv_head_idx],
         m_ref.at[kv_head_idx],
-        o_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :],
+        acc_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :],
         kv_blk_idx=kv_blk_idx,
     )
     return kv_blk_idx + 1, next_buf_idx
@@ -535,6 +539,7 @@ def prefetch_next_kv_blk():
   # Reset seq_idx for next kv_heads_blk if run out of seqs!
   seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
   seq_buf_idx_ref[1] = buf_idx
+  o_ref[...] = acc_ref[...].astype(q_ref.dtype)


 def cdiv(a, b):
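The new acc_ref scratch and the final o_ref[...] = acc_ref[...].astype(q_ref.dtype) write follow the common Pallas pattern of accumulating in an f32 VMEM scratch and downcasting only when the result is written to the output ref. A minimal runnable sketch of that pattern (a toy elementwise kernel with hypothetical names, not the attention kernel; assumes a TPU backend and a tile-friendly shape):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def _toy_kernel(x_ref, o_ref, acc_ref):
  # Do the arithmetic in the f32 scratch ref...
  acc_ref[...] = x_ref[...].astype(jnp.float32) * 2.0
  # ...and downcast once when writing the output ref.
  o_ref[...] = acc_ref[...].astype(o_ref.dtype)

def toy_double(x):
  return pl.pallas_call(
      _toy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),    # output in x.dtype
      scratch_shapes=[pltpu.VMEM(x.shape, jnp.float32)],   # like acc_scratch
  )(x)

# e.g. toy_double(jnp.ones((16, 128), jnp.bfloat16)) -> bf16 array of 2s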
@@ -629,6 +634,7 @@ def ragged_paged_attention(
     num_seqs: the dynamic number of sequences.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
@@ -639,7 +645,8 @@ def ragged_paged_attention(
   Returns:
     The output of the attention.
   """
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens,
+                         num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   num_q, num_q_heads, head_dim = q.shape
@@ -676,6 +683,10 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
       jnp.float32,
   )
+  acc_scratch = pltpu.VMEM(
+      (num_q_per_blk, num_q_heads_per_blk, head_dim),
+      jnp.float32,
+  )
   double_buf_scratch = pltpu.VMEM(
       (
           2,  # For double buffering during DMA copies.
@@ -691,6 +702,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       pltpu.SemaphoreType.DMA((2,)),  # Semaphores for double buffers.
       lm_scratch,  # l_ref
       lm_scratch,  # m_ref
+      acc_scratch,
   ]
   scalar_prefetches = (
       kv_lens,
@@ -721,10 +733,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
          ),
          vmem_limit_bytes=vmem_limit_bytes,
      ),
-      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
+      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
      name="ragged_paged_attention_kernel",
  )

-  # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
-  # to transfer output with desired dtype back to HBM.
-  return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype)
+  return kernel(*scalar_prefetches, q, kv_pages)
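The practical win is HBM traffic for the output: the old path wrote an f32 result to HBM and then cast it there, while the new path writes the query dtype once. Back-of-the-envelope numbers (illustrative only, not from the source), for a bf16 output of shape [4096, 8, 128], assuming the cast re-reads and re-writes the tensor in HBM:

elems = 4096 * 8 * 128                    # 4,194,304 output elements
old = elems * 4 + elems * 4 + elems * 2   # write f32, re-read f32, write bf16 cast
new = elems * 2                           # write bf16 once
print(old / 2**20, new / 2**20)           # 40.0 8.0 (MiB)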

0 commit comments
