Add mla-backend argument #5047

Closed
wants to merge 1 commit into from
5 changes: 3 additions & 2 deletions docs/backend/server_arguments.md
@@ -138,7 +138,8 @@ Please consult the documentation below to learn more about the parameters you ma

## Kernel backend

* `attention_backend`: The backend for attention computation and KV cache management, and can be one of `fa3`, `flashinfer`, `triton` or `torch_native`. When deploying deepseek models, this argument is for specifying the MLA backend it uses.
* `attention_backend`: The backend for multi-head attention computation and KV cache management; one of `fa3`, `flashinfer`, `triton`, or `torch_native`. Defaults to `flashinfer` when FlashInfer is available, otherwise `triton`.
* `mla_backend`: The backend for multi-head latent attention (MLA), used when deploying DeepSeek-style models; one of `fa3`, `flashinfer`, or `triton`. Defaults to `triton`.
* `sampling_backend`: The backend for sampling.

## Constrained Decoding
@@ -192,5 +193,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. **This argument will be deprecated soon! Please use `--attention_backend flashinfer` instead for switching on flashfiner mla!**
* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for deepseek models. **This argument will be deprecated soon! Please use `--mla-backend flashinfer` instead to switch on flashinfer MLA!**
* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper for the flashinfer MLA attention backend. Should only be used when flashinfer is the MLA backend.
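As a quick illustration of how the defaults above resolve, here is a minimal standalone sketch of the fallback logic; it mirrors the `ServerArgs.__post_init__` changes later in this diff, does not import sglang, and uses a plain boolean in place of `is_flashinfer_available()`.

```python
from typing import Optional, Tuple


def resolve_kernel_backends(
    attention_backend: Optional[str],
    mla_backend: Optional[str],
    flashinfer_available: bool,
) -> Tuple[str, str]:
    """Standalone sketch of the default resolution in ServerArgs.__post_init__."""
    if attention_backend is None:
        attention_backend = "flashinfer" if flashinfer_available else "triton"
    if mla_backend is None:
        mla_backend = "triton"
    return attention_backend, mla_backend


# Unset flags fall back to the documented defaults.
assert resolve_kernel_backends(None, None, flashinfer_available=True) == ("flashinfer", "triton")
# Explicit values pass through unchanged.
assert resolve_kernel_backends("fa3", "flashinfer", flashinfer_available=True) == ("fa3", "flashinfer")
```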
4 changes: 2 additions & 2 deletions docs/references/deepseek.md
@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument.
- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html), and Triton. The backend can be selected with the `--mla-backend` argument.

- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
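As a usage sketch for the `--mla-backend` argument described above, the snippet below follows the launch pattern used elsewhere in these docs; the model path, tensor-parallel size, and port are illustrative placeholders rather than values prescribed by this PR.

```python
import subprocess

# Launch a DeepSeek model and pick the MLA kernel explicitly.
# Model path, --tp, and --port are placeholders.
cmd = [
    "python3", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V3",
    "--trust-remote-code",
    "--tp", "8",
    "--port", "30000",
    "--mla-backend", "flashinfer",  # or "triton" (the default) / "fa3" (requires SM >= 90)
]
subprocess.run(cmd, check=True)
```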

@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
```
- The draft models are available on Hugging Face: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). They can also be exported from the original DeepSeek-V3/R1 models with the [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
- Currently when using flashinfer mla wrapper (`--attention-backend flashinfer`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on FlashAttention 3 backend is still under beta.
- Currently, when using the flashinfer MLA wrapper (`--mla-backend flashinfer`) together with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on the FlashAttention 3 backend is still in beta.
- To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
  - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; for larger batch sizes, increase it beyond that default.
  - Set `--cuda-graph-bs`, the list of batch sizes captured by cuda graph. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes in it.
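The constraints in the bullets above can be combined into a single launch sketch. The speculative-algorithm and draft-model flags from the truncated command above are deliberately left as a comment rather than guessed; only the parameters named in the bullets are filled in, and the batch-size numbers are illustrative, not recommendations.

```python
import subprocess

# Sketch: flashinfer MLA backend plus speculative decoding (MTP) at larger batch sizes.
# Per the notes above, --speculative-eagle-topk must stay at 1 with flashinfer MLA,
# and larger batches need --max-running-requests and --cuda-graph-bs raised.
cmd = [
    "python3", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V3-0324",
    "--trust-remote-code",
    "--tp", "8",
    # ... speculative-algorithm / draft-model flags as in the command above ...
    "--speculative-num-steps", "1",
    "--speculative-eagle-topk", "1",       # required with --mla-backend flashinfer
    "--speculative-num-draft-tokens", "2",
    "--mla-backend", "flashinfer",
    "--max-running-requests", "64",        # raise above the MTP default of 32
    "--cuda-graph-bs", "1", "2", "4", "8", "16", "32", "48", "64",  # assumed to take a space-separated list
]
subprocess.run(cmd, check=True)
```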
5 changes: 3 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
@@ -64,6 +64,7 @@
# Put some global args for easy access
global_server_args_dict = {
"attention_backend": ServerArgs.attention_backend,
"mla_backend": ServerArgs.mla_backend,
"sampling_backend": ServerArgs.sampling_backend,
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
"disable_mla": ServerArgs.disable_mla,
@@ -1434,9 +1435,9 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:

# Create seq_lens_cpu when needed
if (
global_server_args_dict["attention_backend"] == "flashinfer_mla"
global_server_args_dict["mla_backend"] == "flashinfer"
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "fa3"
or global_server_args_dict["mla_backend"] == "fa3"
):
seq_lens_cpu = self.seq_lens.cpu()
else:
79 changes: 60 additions & 19 deletions python/sglang/srt/model_executor/model_runner.py
@@ -123,6 +123,10 @@ def __init__(
self.page_size = server_args.page_size
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.use_mla = (
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
)

# Model-specific adjustment
self.model_specific_adjustment()
@@ -139,6 +143,7 @@ def __init__(
global_server_args_dict.update(
{
"attention_backend": server_args.attention_backend,
"mla_backend": server_args.mla_backend,
"sampling_backend": server_args.sampling_backend,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"disable_mla": server_args.disable_mla,
@@ -203,11 +208,17 @@ def initialize(self, min_per_gpu_memory: float):
)
if self.device == "cuda":
self.init_cublas()
self.init_attention_backend()
if not self.use_mla:
self.init_attention_backend()
else:
self.init_mla_backend()
self.init_cuda_graphs()
else:
self.cuda_graph_runner = None
self.init_attention_backend()
if not self.use_mla:
self.init_attention_backend()
else:
self.init_mla_backend()

# auxiliary hidden capture mode. TODO: expose this to server args?
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
@@ -216,31 +227,27 @@
def model_specific_adjustment(self):
server_args = self.server_args

if (
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
):
if self.use_mla:
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if (
server_args.attention_backend == "flashinfer"
server_args.mla_backend == "flashinfer"
or server_args.enable_flashinfer_mla
):
logger.info(
"MLA optimization is turned on. Use flashinfer backend."
)
# Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
server_args.attention_backend = "flashinfer_mla"
server_args.mla_backend = "flashinfer"
elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend == "fa3":
server_args.mla_backend = "flashmla"
elif server_args.mla_backend == "fa3":
logger.info(
f"MLA optimization is turned on. Use flash attention 3 backend."
)
else:
logger.info("MLA optimization is turned on. Use triton backend.")
server_args.attention_backend = "triton"
server_args.mla_backend = "triton"

if server_args.enable_double_sparsity:
logger.info(
@@ -859,20 +866,56 @@ def init_attention_backend(self):
)

self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
elif self.server_args.attention_backend == "fa3":
assert torch.cuda.get_device_capability()[0] >= 9, (
"FlashAttention v3 Backend requires SM>=90. "
"Please use `--attention-backend flashinfer`."
)
logger.warning(
"FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
)
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)

self.attn_backend = FlashAttentionBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)

def init_mla_backend(self):
"""Init mla kernel backend."""
if self.server_args.mla_backend == "flashinfer":
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)

self.attn_backend = FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "flashmla":
elif self.server_args.mla_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
"Please use fa3 or flashinfer mla backend."
)
assert not self.model_config.is_encoder_decoder, (
"Cross attention is not supported in the triton attention backend. "
"Please use fa3 or flashinfer mla backend."
)
assert not self.server_args.enable_double_sparsity, (
"Double sparsity is not supported in the triton attention backend when enabling MLA. "
"Please use fa3 or flashinfer mla backend."
)
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

self.attn_backend = TritonAttnBackend(self)
elif self.server_args.mla_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

self.attn_backend = FlashMLABackend(self)
elif self.server_args.attention_backend == "fa3":
elif self.server_args.mla_backend == "fa3":
assert torch.cuda.get_device_capability()[0] >= 9, (
"FlashAttention v3 Backend requires SM>=90. "
"Please use `--attention-backend flashinfer`."
"Please use flashinfer or triton mla backend."
)
logger.warning(
"FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
@@ -883,9 +926,7 @@ def init_attention_backend(self):

self.attn_backend = FlashAttentionBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
raise ValueError(f"Invalid mla backend: {self.server_args.mla_backend}")

def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"
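To summarize the selection logic that is spread across `model_specific_adjustment` and `init_mla_backend` above, here is a small standalone helper; it is illustrative only and not part of the PR, but the string values and precedence match the diff.

```python
from typing import Optional


def resolve_mla_backend(
    mla_backend: Optional[str],
    enable_flashinfer_mla: bool = False,
    enable_flashmla: bool = False,
) -> str:
    """Which MLA kernel ends up being used for an MLA-architecture model."""
    if mla_backend == "flashinfer" or enable_flashinfer_mla:
        return "flashinfer"  # FlashInferMLAAttnBackend
    if enable_flashmla:
        return "flashmla"    # FlashMLABackend, set via --enable-flashmla
    if mla_backend == "fa3":
        return "fa3"         # FlashAttentionBackend, requires SM >= 90
    return "triton"          # TritonAttnBackend, the default


assert resolve_mla_backend(None) == "triton"
assert resolve_mla_backend("fa3") == "fa3"
assert resolve_mla_backend(None, enable_flashinfer_mla=True) == "flashinfer"
```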
6 changes: 3 additions & 3 deletions python/sglang/srt/models/deepseek_v2.py
@@ -687,11 +687,11 @@ def __init__(
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.attention_backend = global_server_args_dict["attention_backend"]
self.mla_backend = global_server_args_dict["mla_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

def no_absorb(self, forward_batch: ForwardBatch) -> bool:
if self.attention_backend == "flashinfer_mla":
if self.mla_backend == "flashinfer":
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not self.flashinfer_mla_disable_ragged
@@ -700,7 +700,7 @@ def no_absorb(self, forward_batch: ForwardBatch) -> bool:
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
)
elif self.attention_backend == "fa3":
elif self.mla_backend == "fa3":
# Flash Attention: Keep absorbing for all extend/decode
return False
else:
15 changes: 14 additions & 1 deletion python/sglang/srt/server_args.py
@@ -127,6 +127,7 @@ class ServerArgs:

# Kernel backend
attention_backend: Optional[str] = None
mla_backend: Optional[str] = None
sampling_backend: Optional[str] = None
grammar_backend: Optional[str] = None

@@ -265,12 +266,17 @@ def __post_init__(self):
# Choose kernel backends
if self.device == "hpu":
self.attention_backend = "torch_native"
self.mla_backend = "triton"
self.sampling_backend = "pytorch"

if self.attention_backend is None:
self.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
)

if self.mla_backend is None:
self.mla_backend = "triton"

if self.sampling_backend is None:
self.sampling_backend = (
"flashinfer" if is_flashinfer_available() else "pytorch"
@@ -819,6 +825,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--mla-backend",
type=str,
choices=["flashinfer", "triton", "fa3"],
default=ServerArgs.mla_backend,
help="Choose the kernels for multi-head latent attention layers.",
)
parser.add_argument(
"--sampling-backend",
type=str,
@@ -836,7 +849,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--enable-flashinfer-mla",
action="store_true",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--mla-backend flashinfer' instead for switching on flashfiner mla!",
)
parser.add_argument(
"--enable-flashmla",
6 changes: 3 additions & 3 deletions test/srt/test_mla_flashinfer.py
@@ -26,7 +26,7 @@ def setUpClass(cls):
"--enable-torch-compile",
"--cuda-graph-max-bs",
"2",
"--attention-backend",
"--mla-backend",
"flashinfer",
]
)
@@ -70,7 +70,7 @@ def setUpClass(cls):
"--disable-cuda-graph",
"--cuda-graph-max-bs",
"4",
"--attention-backend",
"--mla-backend",
"flashinfer",
]
)
@@ -126,7 +126,7 @@ def setUpClass(cls):
"1",
"--speculative-num-draft-tokens",
"4",
"--attention-backend",
"--mla-backend",
"flashinfer",
]
)