Skip to content

Commit e4155e9

Browse files
authored
Add flash_attn_varlen_func to sgl-kernel (#5315)
1 parent 1b1b47a commit e4155e9

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

sgl-kernel/python/sgl_kernel/flash_attn.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
204204
)
205205
# return (out, softmax_lse) if return_softmax_lse else out
206206
return (out, softmax_lse, *rest) if return_softmax_lse else out
207+
208+
209+
def flash_attn_varlen_func(
210+
q,
211+
k,
212+
v,
213+
cu_seqlens_q,
214+
cu_seqlens_k,
215+
max_seqlen_q,
216+
max_seqlen_k,
217+
seqused_q=None,
218+
seqused_k=None,
219+
softmax_scale=None,
220+
causal=False,
221+
qv=None,
222+
q_descale=None,
223+
k_descale=None,
224+
v_descale=None,
225+
window_size=(-1, -1),
226+
softcap=0.0,
227+
num_splits=1,
228+
pack_gqa=None,
229+
sm_margin=0,
230+
return_softmax_lse=False,
231+
):
232+
if not is_fa3_supported():
233+
raise NotImplementedError(
234+
"flash_attn at sgl-kernel is only supported on sm90 and above"
235+
)
236+
237+
if softmax_scale is None:
238+
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
239+
-0.5
240+
)
241+
242+
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
243+
q,
244+
k,
245+
v,
246+
None, # k_new
247+
None, # v_new
248+
qv, # qv
249+
None, # out
250+
cu_seqlens_q,
251+
cu_seqlens_k,
252+
None, # cu_seqlens_k_new
253+
seqused_q,
254+
seqused_k,
255+
max_seqlen_q,
256+
max_seqlen_k,
257+
None, # page_table,
258+
None, # kv_batch_idx
259+
None, # leftpad_k
260+
None, # rotary cos
261+
None, # rotary sin
262+
None, # seqlens_rotary
263+
q_descale,
264+
k_descale,
265+
v_descale,
266+
softmax_scale,
267+
causal,
268+
window_size[0],
269+
window_size[1],
270+
softcap,
271+
is_rotary_interleaved=False,
272+
scheduler_metadata=None,
273+
num_splits=num_splits,
274+
pack_gqa=pack_gqa,
275+
sm_margin=sm_margin,
276+
)
277+
278+
return (out, softmax_lse, *rest) if return_softmax_lse else out

0 commit comments

Comments
 (0)