
Commit 4b94f5a

Edenzzzz authored and tarinkk committed
Fix "Avoid computing lse in Ragged Prefill when there's no prefix match" (sgl-project#5555)
1 parent 47fd2d6 commit 4b94f5a

3 files changed: +23 −15 lines

docs/backend/server_arguments.md

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `chunked_prefill_size` | Perform prefill in chunks of this size. Larger sizes speed up prefill but increase VRAM usage. Decrease if CUDA runs out of memory. | None |
 | `max_prefill_tokens` | Token budget for how many tokens can be accepted in one prefill batch. The actual limit is the max of this value and `context_length`. | `16384` |
 | `schedule_policy` | The scheduling policy to control how waiting prefill requests are processed by a single engine. | `"fcfs"` |
-| `schedule_conservativeness` | Controls how conservative the server is when accepting new requests. High conservativeness may cause starvation; low conservativeness may reduce performance. | `1.0` |
+| `schedule_conservativeness` | Controls how conservative the server is when accepting new prefill requests. High conservativeness may cause starvation; low conservativeness may slow down decode. | `1.0` |
 | `cpu_offload_gb` | Amount of RAM (in GB) to reserve for offloading model parameters to the CPU. | `0` |
 
 ## Other runtime options
@@ -219,5 +219,5 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `cuda_graph_bs` | The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. | None |
 | `torchao_config` | Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | `int8dq` |
 | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
-| `flashinfer_mla_disable_ragged` | Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend. | `False` |
+| `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
 | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
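`flashinfer_mla_disable_ragged` is a `ServerArgs` field. As a quick illustration (not part of this commit), here is a minimal sketch of toggling it from the offline Engine API; it assumes that `sglang.Engine` forwards `ServerArgs` fields such as `attention_backend` and `flashinfer_mla_disable_ragged` as keyword arguments, and uses an arbitrary MLA model for the example.

```python
# Hypothetical usage sketch -- not part of this commit. Assumes sglang's
# offline Engine accepts ServerArgs fields as keyword arguments.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # example MLA model
    attention_backend="flashinfer",             # use FlashInfer as the MLA backend
    flashinfer_mla_disable_ragged=True,         # skip ragged prefill, always use paged MLA
)
out = llm.generate("The capital of France is")
print(out["text"])
llm.shutdown()
```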

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 20 additions & 12 deletions
@@ -418,6 +418,7 @@ def forward_extend(
 
         logits_soft_cap = layer.logit_cap
 
+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -427,7 +428,7 @@
                 )
 
             o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
@@ -437,20 +438,27 @@
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,
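For context on why the `extend_no_prefix` branch no longer needs the log-sum-exp: the LSE is only consumed when the ragged (new-token) attention output has to be merged with the paged (cached-prefix) attention output. Below is a minimal, self-contained sketch of that merge in plain PyTorch; the names are illustrative, not the SGLang or FlashInfer API.

```python
# Illustrative sketch of the LSE-based merge used when a prefix match exists.
# o1/s1 would come from attention over the new (ragged) tokens, o2/s2 from
# attention over the cached prefix pages.
import torch

def merge_attention_states(o1: torch.Tensor, s1: torch.Tensor,
                           o2: torch.Tensor, s2: torch.Tensor) -> torch.Tensor:
    """Combine two partial attention outputs.

    o1, o2: [num_tokens, num_heads, head_dim] partial attention outputs
    s1, s2: [num_tokens, num_heads] log-sum-exp of the attention scores
            that produced each partial output.
    """
    m = torch.maximum(s1, s2)                # stabilize the exponentials
    w1 = torch.exp(s1 - m).unsqueeze(-1)     # weight of the ragged part
    w2 = torch.exp(s2 - m).unsqueeze(-1)     # weight of the paged part
    return (o1 * w1 + o2 * w2) / (w1 + w2)

# With no prefix match there is no second partial state to merge, so the LSE
# is never consumed -- which is why this commit switches that branch from
# forward_return_lse() to forward().
```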

python/sglang/srt/layers/attention/flashinfer_mla_backend.py

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ def forward_extend(
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
