@@ -18,6 +18,7 @@ def _fwd_kernel_with_v(
     V,
     sm_scale,
     B_Start_Loc,
+    B_Kv_Start_Loc,
     B_Seqlen,  # B_LOC records the true start offset of each batch's input; B_SEQ_len records the true length of the current input
     Out,
     stride_q_bs,
@@ -44,7 +45,8 @@ def _fwd_kernel_with_v(
 
     cur_k_head = cur_head
 
-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_q_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_kv_start_index = tl.load(B_Kv_Start_Loc + cur_batch)
     prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len
 
@@ -55,9 +57,9 @@ def _fwd_kernel_with_v(
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
+    off_q = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :]
     off_q_rope = (
-        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs
+        (cur_batch_in_q_start_index + offs_m[:, None]) * stride_q_rope_bs
         + cur_head * stride_q_rope_h
         + offs_rope_d[None, :]
     )
@@ -84,12 +86,12 @@ def _fwd_kernel_with_v(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
-            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs,
+            k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
         k_rope = tl.load(
-            k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs,
+            k_rope_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_k_rope_bs,
             mask=(start_n + offs_n[None, :]) < block_end_loc,
             other=0.0,
         )
@@ -119,7 +121,7 @@ def _fwd_kernel_with_v(
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(
-            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+            v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs,
             mask=(start_n + offs_n[:, None]) < block_end_loc,
             other=0.0,
         )
@@ -129,7 +131,7 @@ def _fwd_kernel_with_v(
         l_i = l_i_new
         m_i = m_i_new
     # initialize pointers to output
-    off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
+    off_o = (cur_batch_in_q_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
     return
@@ -144,6 +146,7 @@ def context_attention_fwd_with_v(
     v,
     o,
     b_start_loc,
+    b_kv_start_loc,
     b_seq_len,
     b_prompt_cache_len,
     max_input_len,
@@ -181,6 +184,7 @@ def context_attention_fwd_with_v(
         v,
         sm_scale,
         b_start_loc,
+        b_kv_start_loc,
         b_seq_len,
         o,
         q_nope.stride(0),
@@ -204,3 +208,82 @@ def context_attention_fwd_with_v(
         num_stages=1,
     )
     return
+
+
+if __name__ == "__main__":
+    import torch
+    import torch.nn.functional as F  # needed for the cosine similarity check below
+    import flashinfer
+
+    Z, N_CTX, H, D_HEAD, ROPE_HEAD = 32, 1024, 16, 128, 64
+    dtype = torch.bfloat16
+
+    # K/V are packed by the full per-batch sequence length (prefix cache + new tokens)
+    k_nope = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+    k_rope = torch.randn((Z * N_CTX, 1, ROPE_HEAD), dtype=dtype, device="cuda")
+    k = torch.cat([k_nope, torch.repeat_interleave(k_rope, H, dim=-2)], dim=-1)
+    v = torch.randn((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
+
+    max_input_len = Z * N_CTX
+    softmax_scale = 0.117
+    b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
+    b_prompt_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda")
+    b_prompt_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda")
+    q_lens = b_seq_len - b_prompt_cache_len
+    q_start_loc = q_lens.cumsum(0) - q_lens
+    kv_start_loc = b_seq_len.cumsum(0) - b_seq_len
+
+    # queries are packed by the number of new (non-cached) tokens only
+    q_nope = torch.randn((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    q_rope = torch.randn((q_lens.sum(), H, ROPE_HEAD), dtype=dtype, device="cuda")
+    q = torch.cat([q_nope, q_rope], dim=-1)
+
+    o = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o1 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+    o2 = torch.empty((q_lens.sum(), H, D_HEAD), dtype=dtype, device="cuda")
+
+    fn1 = lambda: context_attention_fwd_with_v(
+        q_nope,
+        q_rope,
+        k_nope,
+        k_rope,
+        v,
+        o,
+        q_start_loc,
+        kv_start_loc,
+        b_seq_len,
+        b_prompt_cache_len,
+        max_input_len,
+        softmax_scale,
+    )
+
+    # flashinfer ragged prefill as the reference baseline
+    q_starts = torch.zeros((Z + 1,)).int().cuda()
+    q_starts[1:] = torch.cumsum(b_seq_len - b_prompt_cache_len, dim=0)
+    kv_starts = torch.zeros_like(q_starts)
+    kv_starts[1:] = torch.cumsum(b_seq_len, dim=0)
+    kv_layout = "NHD"
+    batch_size = Z
+    q_indptr = q_starts
+    kv_indptr = kv_starts
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, kv_layout)
+    wrapper.plan(
+        qo_indptr=q_indptr,
+        kv_indptr=kv_indptr,
+        num_qo_heads=H,
+        num_kv_heads=H,
+        head_dim_qk=D_HEAD + ROPE_HEAD,
+        head_dim_vo=D_HEAD,
+        q_data_type=dtype,
+        causal=True,
+        sm_scale=softmax_scale,
+    )
+    fn2 = lambda: wrapper.run(q, k, v, out=o1)
+
+    ms1 = triton.testing.do_bench(fn1)
+    ms2 = triton.testing.do_bench(fn2)
+    cos_sim1 = F.cosine_similarity(o, o1).mean()
+    print(cos_sim1)
+    print(ms1)
+    print(ms2)
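For readers checking the indexing change, a plain PyTorch reference of what the kernel is expected to compute can help: with this patch, queries are packed by the per-batch count of new tokens (offsets in `b_start_loc` / `q_start_loc`), K/V are packed by the full per-batch sequence length including the prefix cache (offsets in the new `b_kv_start_loc` / `kv_start_loc`), and query i of a batch attends causally to the first `prompt_cache_len + i + 1` K/V positions. The sketch below is not part of the patch; `ref_context_attention` is a hypothetical helper written under that assumption, using the same packed `q`, `k`, `v` layout as the `__main__` block above.

import torch

def ref_context_attention(q, k, v, q_start_loc, kv_start_loc, b_seq_len, b_prompt_cache_len, sm_scale):
    # q: (total_new_tokens, H, Dqk), k: (total_kv_tokens, H, Dqk), v: (total_kv_tokens, H, Dv)
    out = torch.empty((q.shape[0], q.shape[1], v.shape[2]), dtype=q.dtype, device=q.device)
    for b in range(b_seq_len.shape[0]):
        kv_len = int(b_seq_len[b])
        cache_len = int(b_prompt_cache_len[b])
        q_len = kv_len - cache_len
        qs, ks = int(q_start_loc[b]), int(kv_start_loc[b])
        qb = q[qs:qs + q_len].float()    # (q_len, H, Dqk)
        kb = k[ks:ks + kv_len].float()   # (kv_len, H, Dqk)
        vb = v[ks:ks + kv_len].float()   # (kv_len, H, Dv)
        # attention scores per head: (H, q_len, kv_len)
        scores = torch.einsum("qhd,khd->hqk", qb, kb) * sm_scale
        # causal mask shifted by the cached prefix: query i sees kv[: cache_len + i + 1]
        q_pos = cache_len + torch.arange(q_len, device=q.device)[:, None]
        kv_pos = torch.arange(kv_len, device=q.device)[None, :]
        scores = scores.masked_fill(kv_pos > q_pos, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        out[qs:qs + q_len] = torch.einsum("hqk,khd->qhd", probs, vb).to(q.dtype)
    return out

With the tensors built in the test above, `F.cosine_similarity(o.flatten(1), ref_context_attention(q, k, v, q_start_loc, kv_start_loc, b_seq_len, b_prompt_cache_len, softmax_scale).flatten(1)).mean()` should be close to 1 if the kernel behaves as assumed.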