diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 70a180a777..f32c0cab73 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -138,7 +138,8 @@ Please consult the documentation below to learn more about the parameters you ma

 ## Kernel backend

-* `attention_backend`: The backend for attention computation and KV cache management, and can be one of `fa3`, `flashinfer`, `triton` or `torch_native`. When deploying deepseek models, this argument is for specifying the MLA backend it uses.
+* `attention_backend`: The backend for multi-head attention computation and KV cache management; one of `fa3`, `flashinfer`, `triton`, or `torch_native`. Defaults to `flashinfer` when available, otherwise `triton`.
+* `mla_backend`: The backend for multi-head latent attention (MLA), used when deploying DeepSeek models; one of `fa3`, `flashinfer`, or `triton`. Defaults to `triton`.
 * `sampling_backend`: The backend for sampling.

 ## Constrained Decoding

@@ -192,5 +193,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. **This argument will be deprecated soon! Please use `--attention_backend flashinfer` instead for switching on flashfiner mla!**
+* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for DeepSeek models. **This argument will be deprecated soon! Please use `--mla-backend flashinfer` instead to enable FlashInfer MLA!**
 * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when flashinfer is used as mla backend turned on.
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 0e4cea70e5..d04cecbcdd 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

-- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument.
+- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with the `--mla-backend` argument.

 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
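Reviewer note: a minimal, self-contained sketch of how the two backend defaults documented above are meant to resolve, mirroring the `__post_init__` changes later in this diff. The `flashinfer_available` flag stands in for sglang's `is_flashinfer_available()` check; the helper itself is hypothetical and only illustrates the documented behavior.

```python
from typing import Optional, Tuple


def resolve_backends(
    attention_backend: Optional[str],
    mla_backend: Optional[str],
    flashinfer_available: bool,
) -> Tuple[str, str]:
    """Mirror the documented defaults: flashinfer (when available) for regular
    attention, triton for multi-head latent attention (MLA)."""
    if attention_backend is None:
        attention_backend = "flashinfer" if flashinfer_available else "triton"
    if mla_backend is None:
        mla_backend = "triton"
    return attention_backend, mla_backend


# Example: no flags given, flashinfer installed -> ("flashinfer", "triton")
print(resolve_backends(None, None, flashinfer_available=True))
```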
@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
 ```
 - The draft model are available at huggingface: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). It can also be exported from original DeepSeek-V3/R1 model with [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
 - The best configuratin for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
-- Currently when using flashinfer mla wrapper (`--attention-backend flashinfer`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on FlashAttention 3 backend is still under beta.
+- Currently, when using the flashinfer MLA wrapper (`--mla-backend flashinfer`) together with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on the FlashAttention 3 backend is still in beta.
 - To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
   - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value.
   - Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
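Reviewer note: the MTP constraint documented above is easy to miss, so here is a tiny illustrative guard that encodes it. The function is hypothetical (not part of sglang); only the flag names come from the docs above.

```python
def check_mtp_config(mla_backend: str, speculative_eagle_topk: int) -> None:
    """Reject the currently unsupported combination described in the docs:
    flashinfer MLA plus speculative decoding requires --speculative-eagle-topk 1."""
    if mla_backend == "flashinfer" and speculative_eagle_topk != 1:
        raise ValueError(
            "With --mla-backend flashinfer, set --speculative-eagle-topk to 1."
        )


check_mtp_config("flashinfer", 1)  # OK
# check_mtp_config("flashinfer", 4)  # would raise ValueError
```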
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 107765eded..a299847784 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -64,6 +64,7 @@
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
+    "mla_backend": ServerArgs.mla_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
@@ -1434,9 +1435,9 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:

         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "flashinfer_mla"
+            global_server_args_dict["mla_backend"] == "flashinfer"
             or global_server_args_dict["enable_flashmla"]
-            or global_server_args_dict["attention_backend"] == "fa3"
+            or global_server_args_dict["mla_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
         else:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3b8b769a67..43e7bf6431 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -123,6 +123,10 @@ def __init__(
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.use_mla = (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        )

         # Model-specific adjustment
         self.model_specific_adjustment()
@@ -139,6 +143,7 @@ def __init__(
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
+                "mla_backend": server_args.mla_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
@@ -203,11 +208,17 @@ def initialize(self, min_per_gpu_memory: float):
             )
             if self.device == "cuda":
                 self.init_cublas()
-                self.init_attention_backend()
+                if not self.use_mla:
+                    self.init_attention_backend()
+                else:
+                    self.init_mla_backend()
                 self.init_cuda_graphs()
             else:
                 self.cuda_graph_runner = None
-                self.init_attention_backend()
+                if not self.use_mla:
+                    self.init_attention_backend()
+                else:
+                    self.init_mla_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
@@ -216,31 +227,27 @@ def initialize(self, min_per_gpu_memory: float):
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        ):
+        if self.use_mla:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
                 if (
-                    server_args.attention_backend == "flashinfer"
+                    server_args.mla_backend == "flashinfer"
                     or server_args.enable_flashinfer_mla
                 ):
                     logger.info(
                         "MLA optimization is turned on. Use flashinfer backend."
                     )
-                    # Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
-                    server_args.attention_backend = "flashinfer_mla"
+                    server_args.mla_backend = "flashinfer"
                 elif server_args.enable_flashmla:
                     logger.info("MLA optimization is turned on. Use flashmla decode.")
Use flashmla decode.") - server_args.attention_backend = "flashmla" - elif server_args.attention_backend == "fa3": + server_args.mla_backend = "flashmla" + elif server_args.mla_backend == "fa3": logger.info( f"MLA optimization is turned on. Use flash attention 3 backend." ) else: logger.info("MLA optimization is turned on. Use triton backend.") - server_args.attention_backend = "triton" + server_args.mla_backend = "triton" if server_args.enable_double_sparsity: logger.info( @@ -859,20 +866,56 @@ def init_attention_backend(self): ) self.attn_backend = TorchNativeAttnBackend(self) - elif self.server_args.attention_backend == "flashinfer_mla": + elif self.server_args.attention_backend == "fa3": + assert torch.cuda.get_device_capability()[0] >= 9, ( + "FlashAttention v3 Backend requires SM>=90. " + "Please use `--attention-backend flashinfer`." + ) + logger.warning( + "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported." + ) + from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, + ) + + self.attn_backend = FlashAttentionBackend(self) + else: + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" + ) + + def init_mla_backend(self): + """Init mla kernel backend.""" + if self.server_args.mla_backend == "flashinfer": from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAAttnBackend, ) self.attn_backend = FlashInferMLAAttnBackend(self) - elif self.server_args.attention_backend == "flashmla": + elif self.server_args.mla_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use fa3 or flashinfer mla backend." + ) + assert not self.model_config.is_encoder_decoder, ( + "Cross attention is not supported in the triton attention backend. " + "Please use fa3 or flashinfer mla backend." + ) + assert not self.server_args.enable_double_sparsity, { + "Double sparsity is not supported in the triton attention backend when enabling MLA. " + "Please use fa3 or flashinfer mla backend." + } + from sglang.srt.layers.attention.triton_backend import TritonAttnBackend + + self.attn_backend = TritonAttnBackend(self) + elif self.server_args.mla_backend == "flashmla": from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend self.attn_backend = FlashMLABackend(self) - elif self.server_args.attention_backend == "fa3": + elif self.server_args.mla_backend == "fa3": assert torch.cuda.get_device_capability()[0] >= 9, ( "FlashAttention v3 Backend requires SM>=90. " - "Please use `--attention-backend flashinfer`." + "Please use flashinfer or triton mla backend." ) logger.warning( "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported." @@ -883,9 +926,7 @@ def init_attention_backend(self): self.attn_backend = FlashAttentionBackend(self) else: - raise ValueError( - f"Invalid attention backend: {self.server_args.attention_backend}" - ) + raise ValueError(f"Invalid mla backend: {self.server_args.mla_backend}") def init_double_sparsity_channel_config(self, selected_channel): selected_channel = "." 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2fcd193d80..6cc3fb06f7 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -687,11 +687,11 @@ def __init__(
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
-        self.attention_backend = global_server_args_dict["attention_backend"]
+        self.mla_backend = global_server_args_dict["mla_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.attention_backend == "flashinfer_mla":
+        if self.mla_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
@@ -700,7 +700,7 @@ def no_absorb(self, forward_batch: ForwardBatch) -> bool:
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
-        elif self.attention_backend == "fa3":
+        elif self.mla_backend == "fa3":
             # Flash Attention: Keep absorbing for all extend/decode
             return False
         else:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 85f65eb74d..314643ef8f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -127,6 +127,7 @@ class ServerArgs:

     # Kernel backend
     attention_backend: Optional[str] = None
+    mla_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None

@@ -265,12 +266,17 @@ def __post_init__(self):
         # Choose kernel backends
         if self.device == "hpu":
             self.attention_backend = "torch_native"
+            self.mla_backend = "triton"
             self.sampling_backend = "pytorch"

         if self.attention_backend is None:
             self.attention_backend = (
                 "flashinfer" if is_flashinfer_available() else "triton"
             )
+
+        if self.mla_backend is None:
+            self.mla_backend = "triton"
+
         if self.sampling_backend is None:
             self.sampling_backend = (
                 "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -819,6 +825,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--mla-backend",
+            type=str,
+            choices=["flashinfer", "triton", "fa3"],
+            default=ServerArgs.mla_backend,
+            help="Choose the kernels for multi-head latent attention layers.",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
@@ -836,7 +849,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--enable-flashinfer-mla",
             action="store_true",
-            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
+            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--mla-backend flashinfer' instead to enable FlashInfer MLA!",
         )
         parser.add_argument(
             "--enable-flashmla",
diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py
index 4f0953e6a3..fb8f0b0416 100644
--- a/test/srt/test_mla_flashinfer.py
+++ b/test/srt/test_mla_flashinfer.py
@@ -26,7 +26,7 @@ def setUpClass(cls):
                 "--enable-torch-compile",
                 "--cuda-graph-max-bs",
                 "2",
-                "--attention-backend",
+                "--mla-backend",
                 "flashinfer",
             ]
         )
@@ -70,7 +70,7 @@ def setUpClass(cls):
                 "--disable-cuda-graph",
                 "--cuda-graph-max-bs",
                 "4",
-                "--attention-backend",
+                "--mla-backend",
                 "flashinfer",
             ]
         )
@@ -126,7 +126,7 @@ def setUpClass(cls):
                 "1",
                 "--speculative-num-draft-tokens",
                 "4",
-                "--attention-backend",
+                "--mla-backend",
                 "flashinfer",
             ]
         )
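Reviewer note: for readers skimming the diff, a condensed sketch of the dispatch that the new `init_mla_backend` performs, written as a plain lookup table. The backend classes below are placeholders standing in for the real sglang imports; only the string keys and the error path follow the code in this PR.

```python
from typing import Callable, Dict


class FlashInferMLAAttnBackend: ...  # placeholder for the real sglang class
class TritonAttnBackend: ...         # placeholder
class FlashMLABackend: ...           # placeholder
class FlashAttentionBackend: ...     # placeholder (fa3)


MLA_BACKENDS: Dict[str, Callable[[], object]] = {
    "flashinfer": FlashInferMLAAttnBackend,
    "triton": TritonAttnBackend,
    "flashmla": FlashMLABackend,
    "fa3": FlashAttentionBackend,
}


def pick_mla_backend(name: str) -> object:
    """Mirror init_mla_backend: look up the backend class, fail on unknown names."""
    try:
        return MLA_BACKENDS[name]()
    except KeyError:
        raise ValueError(f"Invalid mla backend: {name}") from None


print(type(pick_mla_backend("triton")).__name__)  # TritonAttnBackend
```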