
Commit 5c3832e

Edenzzzz authored and Fridge003 committed

Avoid computing lse in Ragged Prefill when there's no prefix. (sgl-project#5476)
Co-authored-by: Baizhou Zhang <[email protected]>

1 parent 39edc9a commit 5c3832e

File tree

3 files changed: +19 -12 lines changed

docs/backend/server_arguments.md

Lines changed: 1 addition & 1 deletion
@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you may
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
+* `flashinfer_mla_disable_ragged`: Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend.
 * `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend.
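
A hedged usage sketch of this flag through the offline engine API follows; the model path and generation settings are illustrative, and it assumes `sglang.Engine` forwards `ServerArgs` fields such as `attention_backend` and `flashinfer_mla_disable_ragged` as keyword arguments, which is not part of this commit:

```python
# Illustrative only: the model path and sampling settings are made up for this sketch,
# and Engine(**kwargs) forwarding to ServerArgs is an assumption, not taken from this commit.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # any MLA model; illustrative choice
    attention_backend="flashinfer",             # FlashInfer must be the MLA backend for this flag to apply
    flashinfer_mla_disable_ragged=True,         # fall back to paged MLA even when there is no prefix match
    trust_remote_code=True,
)
print(llm.generate("The capital of France is", {"temperature": 0.0, "max_new_tokens": 8}))
llm.shutdown()
```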

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 17 additions & 10 deletions
@@ -425,18 +425,25 @@ def forward_extend(
                 v_scale=v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = prefill_wrapper_paged.forward(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=not layer.is_cross_attention,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=k_scale,
+                    v_scale=v_scale,
+                )
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
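
Background on why the lse matters here: `forward_return_lse` returns the per-token, per-head log-sum-exp of the attention scores, which is only needed to merge the ragged output over the new tokens (`o1`, `s1`) with the paged output over the cached prefix (`o2`, `s2`). With no prefix there is nothing to merge, so a plain `forward` call suffices. A minimal PyTorch sketch of such a merge, using the standard numerically stable log-sum-exp combination rather than the exact FlashInfer merge kernel the backend actually calls:

```python
import torch

def merge_attention_states(o1, s1, o2, s2):
    """Combine two partial attention outputs computed over disjoint key/value sets.

    o1, o2: [tokens, heads, head_dim] partial attention outputs
    s1, s2: [tokens, heads] log-sum-exp of the attention scores behind each output
    """
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # renormalization weight of the first partial output
    w2 = torch.exp(s2 - s_max)  # renormalization weight of the second partial output
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    s = s_max + torch.log(w1 + w2)  # lse of the merged state
    return o, s
```

The sketch only illustrates why both the output and its lse must be carried when a prefix is present; in the backend the merge itself is done by FlashInfer's state-merging utilities.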

python/sglang/srt/layers/attention/flashinfer_mla_backend.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def forward_extend(

         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
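
This mirrors the change in `flashinfer_backend.py`: the old code already discarded the returned lse (`o, _ = ...`), so calling `forward` is functionally equivalent on this ragged path while skipping the lse computation altogether.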
