
Commit 6463e23

trevor-m authored and WineChord committed
Add Cutlass MLA attention backend (sgl-project#5390)
1 parent ca14af9 commit 6463e23


7 files changed: +305 −3 lines changed


docs/backend/server_arguments.md

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma
 
 ## Kernel backend
 
-* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend.
+* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, `cutlass_mla`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend.
 * `sampling_backend`: The backend for sampling.
 
 ## Constrained Decoding
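
For reference (not part of this diff): the new value is selected through the same flag the bullet above documents. A minimal launch sketch, assuming the standard `python -m sglang.launch_server` entry point; the model path is illustrative and other flags (tensor parallelism, etc.) are omitted:

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --attention-backend cutlass_mla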
python/sglang/srt/layers/attention/cutlass_mla_backend.py

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@

from __future__ import annotations

"""
Support attention backend for Cutlass MLA.

"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
import triton

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_cuda

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpecInfo

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size


# Cutlass MLA only supports pagesize=128
PAGE_SIZE = 128


@dataclass
class CutlassMLADecodeMetadata:
    workspace: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None

    def __init__(
        self,
        workspace: Optional[torch.Tensor] = None,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        self.workspace = workspace
        self.block_kv_indices = block_kv_indices


class CutlassMLABackend(FlashInferMLAAttnBackend):
    """Cutlass attention kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(
            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
        )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.forward_metadata: Union[CutlassMLADecodeMetadata] = None
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

    def init_forward_metadata(self, forward_batch: ForwardBatch):

        bs = forward_batch.batch_size
        spec_info = forward_batch.spec_info
        if forward_batch.forward_mode.is_decode_or_idle():
            if spec_info is None:
                max_seqlen_pad = triton.cdiv(
                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                )
                block_kv_indices = torch.full(
                    (bs, max_seqlen_pad),
                    -1,
                    dtype=torch.int32,
                    device=forward_batch.seq_lens.device,
                )
                create_flashmla_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    block_kv_indices,
                    self.req_to_token.stride(0),
                    max_seqlen_pad,
                    PAGE_SIZE,
                )
                workspace_size = cutlass_mla_get_workspace_size(
                    max_seqlen_pad * PAGE_SIZE, bs
                )
                workspace = torch.empty(
                    workspace_size, device="cuda", dtype=torch.uint8
                )
                self.forward_metadata = CutlassMLADecodeMetadata(
                    workspace,
                    block_kv_indices,
                )
            else:
                super().init_forward_metadata(forward_batch)
        else:
            super().init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        workspace_size = cutlass_mla_get_workspace_size(
            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
        )
        self.cuda_graph_mla_workspace = torch.empty(
            workspace_size, device="cuda", dtype=torch.uint8
        )
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

                create_flashmla_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    req_pool_indices,
                    seq_lens,
                    None,
                    self.cuda_graph_kv_indices,
                    self.req_to_token.stride(0),
                    self.cuda_graph_kv_indices.stride(0),
                    PAGE_SIZE,
                )
                workspace_size = cutlass_mla_get_workspace_size(
                    max_seqlen_pad * PAGE_SIZE, bs
                )
                self.cuda_graph_mla_workspace = torch.empty(
                    workspace_size, device="cuda", dtype=torch.uint8
                )
                self.forward_metadata = CutlassMLADecodeMetadata(
                    self.cuda_graph_mla_workspace,
                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
                )
        else:
            super().init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
                req_pool_indices,
                seq_lens,
                encoder_lens,
                forward_mode,
                spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):

        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            seq_lens = seq_lens[:bs]
            seq_lens_cpu = seq_lens_cpu[:bs]
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
                PAGE_SIZE,
            )
            workspace_size = cutlass_mla_get_workspace_size(
                max_seqlen_pad * PAGE_SIZE, bs
            )
            self.cuda_graph_mla_workspace = torch.empty(
                workspace_size, device="cuda", dtype=torch.uint8
            )
            self.forward_metadata.workspace = self.cuda_graph_mla_workspace
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        else:
            super().init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                encoder_lens,
                forward_mode,
                spec_info,
                seq_lens_cpu,
            )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)

        o = cutlass_mla_decode(
            q_nope_and_q_pe=reshape_q,
            kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
            seq_lens=forward_batch.seq_lens.to(torch.int32),
            page_table=self.forward_metadata.block_kv_indices,
            workspace=self.forward_metadata.workspace,
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
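
To make the decode-metadata sizing above concrete, here is a small CPU-only sketch of my own (not part of the commit); it assumes only that torch and triton are importable. The block table is padded to the longest sequence in whole 128-token pages and initialized to -1, exactly the shape and fill that `create_flashmla_kv_indices_triton` later overwrites with real page indices:

import torch
import triton  # only triton.cdiv is used here, mirroring the backend code above

PAGE_SIZE = 128  # Cutlass MLA only supports a page size of 128

# Toy decode batch: three requests with these KV lengths (illustrative values).
seq_lens_cpu = torch.tensor([1, 130, 700], dtype=torch.int32)
bs = seq_lens_cpu.numel()

# Pad the block table to the longest sequence, in whole pages.
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)  # ceil(700 / 128) = 6

# Same shape and fill value as CutlassMLADecodeMetadata.block_kv_indices before
# the Triton kernel populates it; trailing slots stay at -1 as padding.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
print(block_kv_indices.shape)  # torch.Size([3, 6])

# Pages actually needed per request.
pages_per_req = (seq_lens_cpu + PAGE_SIZE - 1) // PAGE_SIZE
print(pages_per_req.tolist())  # [1, 2, 6]

The workspace buffer is then sized from `max_seqlen_pad * PAGE_SIZE` and the batch size via `cutlass_mla_get_workspace_size`, as in the regular, capture, and replay paths above.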

python/sglang/srt/layers/attention/utils.py

Lines changed: 1 addition & 1 deletion

@@ -49,8 +49,8 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    PAGED_SIZE: tl.constexpr = 64,
 ):
-    PAGED_SIZE: tl.constexpr = 64
     BLOCK_SIZE: tl.constexpr = 4096
     NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)
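
The callers added in this commit pass `PAGE_SIZE` (128) explicitly, while existing callers that omit the argument keep the previous value of 64. As I read the kernel body (not shown in this hunk), each output entry is the token-slot index found at every `PAGED_SIZE`-th position of the request's `req_to_token` row, divided by `PAGED_SIZE`; the CPU-only emulation below is my own sketch under that assumption and under page-aligned slot allocation (the real work happens in the Triton kernel on GPU):

import torch


def emulate_kv_page_indices(req_to_token_row, seq_len, paged_size=64):
    """CPU stand-in for the per-request output of create_flashmla_kv_indices_triton,
    assuming page-aligned, contiguous slot allocation (my reading, not the kernel itself)."""
    num_pages = (seq_len + paged_size - 1) // paged_size
    # Slot index of the first token in each page, converted to a page id.
    first_token_slots = req_to_token_row[:seq_len:paged_size][:num_pages]
    return first_token_slots // paged_size


# Toy request: 300 tokens stored contiguously starting at slot 1024 (page-aligned).
row = torch.arange(1024, 1024 + 512, dtype=torch.int32)

print(emulate_kv_page_indices(row, seq_len=300, paged_size=64).tolist())
# [16, 17, 18, 19, 20]  -> five 64-token pages (the old hard-coded default)
print(emulate_kv_page_indices(row, seq_len=300, paged_size=128).tolist())
# [8, 9, 10]            -> three 128-token pages, as passed by CutlassMLABackend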

python/sglang/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions

@@ -1515,6 +1515,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
+            or global_server_args_dict["attention_backend"] == "cutlass_mla"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
         else:

python/sglang/srt/model_executor/model_runner.py

Lines changed: 7 additions & 0 deletions

@@ -271,6 +271,7 @@ def model_specific_adjustment(self):
                 "fa3",
                 "triton",
                 "flashmla",
+                "cutlass_mla",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -926,6 +927,12 @@ def init_attention_backend(self):
             )

            self.attn_backend = FlashAttentionBackend(self)
+        elif self.server_args.attention_backend == "cutlass_mla":
+            from sglang.srt.layers.attention.cutlass_mla_backend import (
+                CutlassMLABackend,
+            )
+
+            self.attn_backend = CutlassMLABackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"

python/sglang/srt/server_args.py

Lines changed: 14 additions & 1 deletion

@@ -256,6 +256,12 @@ def __post_init__(self):
             )
             self.page_size = 64

+        if self.attention_backend == "cutlass_mla":
+            logger.warning(
+                "Cutlass MLA only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -823,7 +829,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )

sgl-kernel/python/sgl_kernel/attention.py

Lines changed: 3 additions & 0 deletions

@@ -78,6 +78,7 @@ def cutlass_mla_decode(
     assert len(page_table.shape) == 2
     B_block_table, block_num = page_table.shape
     assert B_block_table == B_q
+    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
     assert block_num % (128 / PAGE_SIZE) == 0

     # TODO(kaixih@nvidia): support fp8
@@ -109,6 +110,8 @@ def cutlass_mla_decode(
 def cutlass_mla_get_workspace_size(
     max_seq_len: int, num_batches: int, sm_count: int = 0
 ) -> int:
+    assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
+    assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
     return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
         max_seq_len, num_batches, sm_count
     )
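
A small usage sketch of my own (assuming a CUDA device and an sgl_kernel build that provides these ops); it mirrors how CutlassMLABackend sizes its scratch buffer from the padded page count, and shows what the new assertions guard against:

import torch
from sgl_kernel import cutlass_mla_get_workspace_size

PAGE_SIZE = 128       # page size required by the Cutlass MLA kernel
bs = 4                # illustrative decode batch size
max_seqlen_pad = 6    # longest sequence padded to whole pages, e.g. ceil(700 / 128)

# Query the workspace size for the padded maximum sequence length, then allocate
# it as raw bytes, exactly as CutlassMLABackend.init_forward_metadata does.
workspace_size = cutlass_mla_get_workspace_size(max_seqlen_pad * PAGE_SIZE, bs)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

# The new asserts reject degenerate inputs up front, e.g.:
#   cutlass_mla_get_workspace_size(0, bs)  -> AssertionError: max_seq_len must be greater than 0, got 0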
