@@ -81,8 +81,8 @@ def ref_ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
 ):
-  check_inputs_shapes(queries, kv_pages, kv_lens, page_indices, cu_q_lens,
-                      num_seqs)
+  validate_static_inputs(queries, kv_pages, kv_lens, page_indices, cu_q_lens,
+                         num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   _, _, num_combined_kv_heads, head_dim = kv_pages.shape
@@ -124,7 +124,7 @@ def ref_ragged_paged_attention(


 # Expect to run these checks during runtime.
-def validate_inputs_on_runtime(
+def validate_dynamic_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.
     Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
@@ -135,7 +135,8 @@ def validate_inputs_on_runtime(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
 ):
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens,
+                         num_seqs, sliding_window, soft_cap)
   max_num_batched_tokens = q.shape[0]
   page_size = kv_pages.shape[1]
   max_num_seqs, pages_per_seq = page_indices.shape
@@ -157,21 +158,19 @@ def validate_inputs_on_runtime(
     if q_len > kv_len:
       raise ValueError(
           f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.")
-  if sliding_window is not None and sliding_window <= 0:
-    raise ValueError(f"{sliding_window=} must be positive.")
-  if soft_cap is not None and soft_cap == 0.0:
-    raise ValueError(f"{soft_cap=} must not be 0.0.")


 # Expect to run these checks during compile time.
-def check_inputs_shapes(
+def validate_static_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.
     Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
@@ -198,6 +197,10 @@ def check_inputs_shapes(
198197 f" { cu_q_lens .dtype = } ." )
199198 if num_q_heads % num_kv_heads != 0 :
200199 raise ValueError (f"{ num_q_heads = } must be divisible by { num_kv_heads = } " )
200+ if sliding_window is not None and sliding_window <= 0 :
201+ raise ValueError (f"{ sliding_window = } must be positive." )
202+ if soft_cap is not None and soft_cap == 0.0 :
203+ raise ValueError (f"{ soft_cap = } must not be 0.0." )
201204
202205
203206def ragged_paged_attention_kernel (
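
Note on the validation split: with sliding_window and soft_cap folded into validate_static_inputs, invalid values are rejected from static (shape/dtype/constant) information alone, before any of the data-dependent checks in validate_dynamic_inputs run. A quick illustration, using hypothetical toy shapes that are assumed to satisfy the remaining static checks (not taken from this change):

import jax.numpy as jnp

# Toy, self-consistent inputs (hypothetical, for illustration only).
q = jnp.zeros((16, 8, 128), dtype=jnp.bfloat16)             # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages = jnp.zeros((32, 16, 2, 128), dtype=jnp.bfloat16)  # [total_num_pages, page_size, 2 * num_kv_heads, head_dim]
kv_lens = jnp.zeros((4,), dtype=jnp.int32)                  # i32[max_num_seqs]
page_indices = jnp.zeros((4, 8), dtype=jnp.int32)           # i32[max_num_seqs, pages_per_seq]
cu_q_lens = jnp.zeros((5,), dtype=jnp.int32)                # i32[max_num_seqs + 1]
num_seqs = jnp.array([1], dtype=jnp.int32)                  # i32[1]

try:
  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens,
                         num_seqs, sliding_window=0, soft_cap=None)
except ValueError as e:
  print(e)  # sliding_window=0 must be positive.
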
@@ -218,6 +221,7 @@ def ragged_paged_attention_kernel(
     sems,  # [2, 2]
     l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
     m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
+    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
     *,
     sm_scale: float,
     sliding_window: int | None = None,
@@ -341,7 +345,7 @@ def flash_attention(
       v,  # [num_kv_per_blk, head_dim]
       head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
       head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
-      head_o_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
+      head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
       *,
       kv_blk_idx,
   ):
@@ -362,7 +366,7 @@ def flash_attention(
         num_q_per_blk * num_q_heads_per_kv_head,
         128,
     )
-    assert head_o_ref.shape == (
+    assert head_acc_ref.shape == (
         num_q_per_blk,
         num_q_heads_per_kv_head,
         head_dim,
@@ -398,8 +402,8 @@ def init_scratch_ref():
           num_q_heads_per_kv_head,
       )
       masked_store(
-          head_o_ref,
-          jnp.zeros_like(head_o_ref),
+          head_acc_ref,
+          jnp.zeros_like(head_acc_ref),
           store_start,
           store_end,
       )
@@ -457,17 +461,17 @@ def broadcast_to_shape(arr, shape):
       return jnp.concatenate([arr for _ in range(shape[1] // arr.shape[1])],
                              axis=1)

-    o_curr = head_o_ref[...].reshape(-1, head_dim)
+    o_curr = head_acc_ref[...].reshape(-1, head_dim)
     l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
     beta = broadcast_to_shape(beta, qkv.shape)
     l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
     out = lax.div(
         l_alpha * o_curr + beta * qkv,
         l_next_safe,
-    ).astype(head_o_ref.dtype)
+    )
     masked_store(
-        head_o_ref,
-        out.reshape(head_o_ref.shape),
+        head_acc_ref,
+        out.reshape(head_acc_ref.shape),
         store_start,
         store_end,
     )
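
For context on the `l_alpha * o_curr + beta * qkv` update above: this is the usual online-softmax rescaling, where the previously accumulated (already normalized) output is rescaled by the old denominator and max-shift before the new KV block's contribution is folded in. A pure-jnp sketch of that bookkeeping (variable names here are illustrative, not the kernel's):

import jax
import jax.numpy as jnp

def online_softmax_attention(q, k, v, block=16):
  # Running max, denominator, and normalized output, updated block by block.
  m = jnp.full((q.shape[0], 1), -jnp.inf)
  l = jnp.zeros((q.shape[0], 1))
  o = jnp.zeros((q.shape[0], v.shape[1]))
  for start in range(0, k.shape[0], block):
    s = q @ k[start:start + block].T
    m_next = jnp.maximum(m, s.max(axis=-1, keepdims=True))
    alpha = jnp.exp(m - m_next)      # rescales the old accumulator
    p = jnp.exp(s - m_next)          # new block's unnormalized weights
    l_next = alpha * l + p.sum(axis=-1, keepdims=True)
    o = (alpha * l * o + p @ v[start:start + block]) / l_next
    m, l = m_next, l_next
  return o

q = jnp.ones((4, 8))
k = jnp.linspace(-1.0, 1.0, 64 * 8).reshape(64, 8)
v = jnp.arange(64.0 * 8).reshape(64, 8) / 100.0
ref = jax.nn.softmax(q @ k.T, axis=-1) @ v
print(jnp.allclose(online_softmax_attention(q, k, v), ref, atol=1e-4))  # True
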
@@ -513,7 +517,7 @@ def prefetch_next_kv_blk():
           v,
           l_ref.at[kv_head_idx],
           m_ref.at[kv_head_idx],
-          o_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :],
+          acc_ref.at[:, q_head_idx:q_head_idx + num_q_heads_per_kv_head, :],
           kv_blk_idx=kv_blk_idx,
       )
       return kv_blk_idx + 1, next_buf_idx
@@ -535,6 +539,7 @@ def prefetch_next_kv_blk():
   # Reset seq_idx for next kv_heads_blk if run out of seqs!
   seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
   seq_buf_idx_ref[1] = buf_idx
+  o_ref[...] = acc_ref[...].astype(q_ref.dtype)


 def cdiv(a, b):
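
The new `o_ref[...] = acc_ref[...].astype(q_ref.dtype)` line is the key numerical change: partial results now live in a float32 acc_ref scratch across KV blocks and are cast to the query dtype only once, at the final store, instead of rounding through the (possibly bf16) output buffer after every block (this is what the removed TODO in the out_shape hunk further down asked for). A toy illustration of why a single final cast helps, in plain jnp and not the kernel itself:

import jax
import jax.numpy as jnp

def accumulate(blocks, acc_dtype):
  # Running sum kept in acc_dtype and re-cast after every block,
  # mimicking per-block stores to an output buffer of that dtype.
  acc = jnp.zeros_like(blocks[0], dtype=acc_dtype)
  for blk in blocks:
    acc = (acc + blk).astype(acc_dtype)
  return acc.astype(jnp.bfloat16)  # final output dtype, analogous to q.dtype

blocks = [jax.random.normal(jax.random.PRNGKey(i), (128,)) for i in range(64)]
ref = accumulate(blocks, jnp.float32)     # f32 accumulator, one cast at the end
lossy = accumulate(blocks, jnp.bfloat16)  # re-rounded to bf16 after every block
print(jnp.max(jnp.abs(lossy.astype(jnp.float32) - ref.astype(jnp.float32))))
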
@@ -629,6 +634,7 @@ def ragged_paged_attention(
     num_seqs: the dynamic number of sequences.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
@@ -639,7 +645,8 @@ def ragged_paged_attention(
   Returns:
     The output of the attention.
   """
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens,
+                         num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   num_q, num_q_heads, head_dim = q.shape
@@ -676,6 +683,10 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
       jnp.float32,
   )
+  acc_scratch = pltpu.VMEM(
+      (num_q_per_blk, num_q_heads_per_blk, head_dim),
+      jnp.float32,
+  )
   double_buf_scratch = pltpu.VMEM(
       (
           2,  # For double buffering during DMA copies.
@@ -691,6 +702,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       pltpu.SemaphoreType.DMA((2,)),  # Semaphores for double buffers.
       lm_scratch,  # l_ref
       lm_scratch,  # m_ref
+      acc_scratch,
   ]
   scalar_prefetches = (
       kv_lens,
@@ -721,10 +733,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
           ),
           vmem_limit_bytes=vmem_limit_bytes,
       ),
-      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
+      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
       name="ragged_paged_attention_kernel",
   )

-  # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
-  # to transfer output with desired dtype back to HBM.
-  return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype)
+  return kernel(*scalar_prefetches, q, kv_pages)
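
Since soft_cap now appears in both the docstring and the static checks: logit soft capping bounds the raw attention scores with a scaled tanh, which is why a cap of exactly 0.0 is rejected above (it would divide by zero). A small sketch of the usual form of this transform, illustrative only and not code from this change:

import jax.numpy as jnp

def cap_logits(logits, soft_cap):
  # Squash raw attention scores into the open interval (-soft_cap, soft_cap).
  return soft_cap * jnp.tanh(logits / soft_cap)

scores = jnp.array([-200.0, -5.0, 0.0, 5.0, 200.0])
print(cap_logits(scores, 50.0))  # approx. [-50., -4.98, 0., 4.98, 50.]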