diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index c6e1e4850d..76206aaeef 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -122,7 +122,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `chunked_prefill_size` | Perform prefill in chunks of this size. Larger sizes speed up prefill but increase VRAM usage. Decrease if CUDA runs out of memory. | None |
 | `max_prefill_tokens` | Token budget for how many tokens can be accepted in one prefill batch. The actual limit is the max of this value and `context_length`. | `16384` |
 | `schedule_policy` | The scheduling policy to control how waiting prefill requests are processed by a single engine. | `"fcfs"` |
-| `schedule_conservativeness` | Controls how conservative the server is when accepting new requests. High conservativeness may cause starvation; low conservativeness may reduce performance. | `1.0` |
+| `schedule_conservativeness` | Controls how conservative the server is when accepting new prefill requests. High conservativeness may cause starvation; low conservativeness may slow down decode. | `1.0` |
 | `cpu_offload_gb` | Amount of RAM (in GB) to reserve for offloading model parameters to the CPU. | `0` |
 
 ## Other runtime options
@@ -219,5 +219,5 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `cuda_graph_bs` | The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. | None |
 | `torchao_config` | Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | `int8dq` |
 | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
-| `flashinfer_mla_disable_ragged` | Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend. | `False` |
+| `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
 | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index d568170ce2..1c254c4fa5 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -418,6 +418,7 @@ def forward_extend(
 
         logits_soft_cap = layer.logit_cap
 
+        q = q.contiguous()
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -427,7 +428,7 @@ def forward_extend(
                 )
 
             o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
@@ -437,20 +438,27 @@ def forward_extend(
                 v_scale=layer.v_scale,
             )
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
             if self.forward_metadata.extend_no_prefix:
-                o = o1
+                o = self.prefill_wrapper_ragged.forward(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
             else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index dded9a9702..58982f3e82 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -355,7 +355,7 @@ def forward_extend(
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
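For readers unfamiliar with the merge path that the ragged-prefill hunk above short-circuits: when an extend batch has a cached prefix, attention is computed twice, once over the new (ragged) tokens and once over the paged prefix, and the two partial outputs are combined using their log-sum-exp (LSE) values. When `extend_no_prefix` is true there is only one partial output, so the LSE is never consumed and the cheaper `forward` call suffices. The snippet below is a minimal NumPy sketch of that merge identity, not SGLang or FlashInfer code; the helper names, the natural-log LSE convention, and the omission of causal masking are illustrative assumptions.

```python
# Hedged sketch: why an LSE-based merge is only needed when a query attends to
# two key sets (new tokens + cached prefix). Not SGLang/FlashInfer code.
import numpy as np

def attn_with_lse(q, k, v, scale):
    """Single-head attention over one key set; also returns log-sum-exp of the logits."""
    logits = (q @ k.T) * scale                 # [n_q, n_k]
    lse = np.log(np.exp(logits).sum(-1))       # [n_q]
    p = np.exp(logits - lse[:, None])          # softmax
    return p @ v, lse

def merge_states(o1, s1, o2, s2):
    """Combine partial attention outputs over disjoint key sets via their LSEs."""
    w1 = 1.0 / (1.0 + np.exp(s2 - s1))         # softmax over the two LSEs
    return w1[:, None] * o1 + (1.0 - w1)[:, None] * o2

rng = np.random.default_rng(0)
d = 8
q = rng.standard_normal((4, d))
k_new, v_new = rng.standard_normal((6, d)), rng.standard_normal((6, d))
k_pre, v_pre = rng.standard_normal((5, d)), rng.standard_normal((5, d))
scale = 1.0 / np.sqrt(d)

# With a cached prefix: run attention over each key set, then merge via LSEs.
o_new, s_new = attn_with_lse(q, k_new, v_new, scale)
o_pre, s_pre = attn_with_lse(q, k_pre, v_pre, scale)
merged = merge_states(o_new, s_new, o_pre, s_pre)

# Reference: one attention pass over all keys.
ref, _ = attn_with_lse(
    q, np.concatenate([k_new, k_pre]), np.concatenate([v_new, v_pre]), scale
)
assert np.allclose(merged, ref)

# With no prefix there is nothing to merge, so the LSE is never needed,
# which is what lets the diff call `forward` instead of `forward_return_lse`.
```

Under these assumptions, the `q = q.contiguous()` hoisting in the same file is an orthogonal cleanup: the query is made contiguous once instead of per call site.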