File tree Expand file tree Collapse file tree 1 file changed +72
-0
lines changed
sgl-kernel/python/sgl_kernel Expand file tree Collapse file tree 1 file changed +72
-0
lines changed Original file line number Diff line number Diff line change @@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
204
204
)
205
205
# return (out, softmax_lse) if return_softmax_lse else out
206
206
return (out , softmax_lse , * rest ) if return_softmax_lse else out
207
+
208
+
209
+ def flash_attn_varlen_func (
210
+ q ,
211
+ k ,
212
+ v ,
213
+ cu_seqlens_q ,
214
+ cu_seqlens_k ,
215
+ max_seqlen_q ,
216
+ max_seqlen_k ,
217
+ seqused_q = None ,
218
+ seqused_k = None ,
219
+ softmax_scale = None ,
220
+ causal = False ,
221
+ qv = None ,
222
+ q_descale = None ,
223
+ k_descale = None ,
224
+ v_descale = None ,
225
+ window_size = (- 1 , - 1 ),
226
+ softcap = 0.0 ,
227
+ num_splits = 1 ,
228
+ pack_gqa = None ,
229
+ sm_margin = 0 ,
230
+ return_softmax_lse = False ,
231
+ ):
232
+ if not is_fa3_supported ():
233
+ raise NotImplementedError (
234
+ "flash_attn at sgl-kernel is only supported on sm90 and above"
235
+ )
236
+
237
+ if softmax_scale is None :
238
+ softmax_scale = (q .shape [- 1 ] + (qv .shape [- 1 ] if qv is not None else 0 )) ** (
239
+ - 0.5
240
+ )
241
+
242
+ out , softmax_lse , * rest = torch .ops .sgl_kernel .fwd .default (
243
+ q ,
244
+ k ,
245
+ v ,
246
+ None , # k_new
247
+ None , # v_new
248
+ qv , # qv
249
+ None , # out
250
+ cu_seqlens_q ,
251
+ cu_seqlens_k ,
252
+ None , # cu_seqlens_k_new
253
+ seqused_q ,
254
+ seqused_k ,
255
+ max_seqlen_q ,
256
+ max_seqlen_k ,
257
+ None , # page_table,
258
+ None , # kv_batch_idx
259
+ None , # leftpad_k
260
+ None , # rotary cos
261
+ None , # rotary sin
262
+ None , # seqlens_rotary
263
+ q_descale ,
264
+ k_descale ,
265
+ v_descale ,
266
+ softmax_scale ,
267
+ causal ,
268
+ window_size [0 ],
269
+ window_size [1 ],
270
+ softcap ,
271
+ is_rotary_interleaved = False ,
272
+ scheduler_metadata = None ,
273
+ num_splits = num_splits ,
274
+ pack_gqa = pack_gqa ,
275
+ sm_margin = sm_margin ,
276
+ )
277
+
278
+ return (out , softmax_lse , * rest ) if return_softmax_lse else out
You can’t perform that action at this time.
0 commit comments