Skip to content

Support FlashMLA backend cuda graph #4514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 19, 2025
214 changes: 184 additions & 30 deletions python/sglang/srt/layers/attention/flashmla_backend.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from __future__ import annotations

"""
Support attention backend for flashMLA.
Support attention backend for FlashMLA.

Current initial integration of FlashMLA shows normal accuracy, but performance is slightly lacking.
#TODO
Support FlashMLA decode with cudagraph
Enable speculative sampling in FlashMLA
Integrate FA3 prefill
"""


from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch
Expand All @@ -28,10 +25,30 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInfo


# FlashMLA only supports pagesize=64
PAGE_SIZE = 64
# TODO The current setup is hard-coded and will be changed after integrating with MTP.
Q_LEN = 1


@dataclass
class FlashMLADecodeMetadata:
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
num_splits: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None

def __init__(
self,
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
num_splits: Optional[torch.Tensor] = None,
block_kv_indices: Optional[torch.Tensor] = None,
):
self.flashmla_metadata = flashmla_metadata
self.num_splits = num_splits
self.block_kv_indices = block_kv_indices


class FlashMLABackend(FlashInferMLAAttnBackend):
Expand All @@ -58,6 +75,7 @@ def __init__(
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.forward_metadata: Union[FlashMLADecodeMetadata] = None
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
Expand All @@ -67,6 +85,163 @@ def __init__(
self.q_data_type = model_runner.dtype
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

def init_forward_metadata(self, forward_batch: ForwardBatch):

bs = forward_batch.batch_size
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens.max().item(), PAGE_SIZE
)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
else:
super().init_forward_metadata(forward_batch)
else:
super().init_forward_metadata(forward_batch)

def init_cuda_graph_state(
self,
max_bs: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
cuda_graph_kv_indices = torch.full(
(max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
1,
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = block_kv_indices

self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.cuda_graph_kv_indices = cuda_graph_kv_indices

def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = FlashMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)

else:
super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)

def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
seq_lens_cpu: Optional[torch.Tensor],
):

if forward_mode.is_decode_or_idle():
seq_lens = seq_lens[:bs]
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should avoid CPU-GPU synchronization by avoiding the use of seq_lens.max().item().
Can you derive this value from seq_lens_cpu?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this problem. I was too busy at work today and didn't have time to modify it. I will fix it tomorrow.

create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
Q_LEN * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]

else:
super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)

def forward_decode(
self,
q: torch.Tensor,
Expand All @@ -88,39 +263,18 @@ def forward_decode(
v,
)
bs = forward_batch.batch_size

max_seqlen_pad = triton.cdiv(forward_batch.seq_lens.max().item(), PAGE_SIZE)
flashmla_index = torch.full(
(bs, max_seqlen_pad), -1, dtype=torch.int32, device=q.device
)
create_flashmla_kv_indices_triton[(bs,)](
self.indices_updater_decode.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
flashmla_index,
self.indices_updater_decode.req_to_token.size(1),
flashmla_index.size(1),
max_seqlen_pad,
)

mla_metadata, mla_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
1 * self.num_q_heads // self.num_kv_heads,
self.num_kv_heads,
)

k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)

o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=flashmla_index,
block_table=self.forward_metadata.block_kv_indices,
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
tile_scheduler_metadata=mla_metadata,
num_splits=mla_splits,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=False,
)
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/layers/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def create_flashmla_kv_indices_triton(
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
kv_indices_ptr_stride: tl.constexpr,
max_pagesize: tl.constexpr,
):
PAGED_SIZE: tl.constexpr = 64
BLOCK_SIZE: tl.constexpr = 4096
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ def __post_init__(self):
assert self.chunked_prefill_size % self.page_size == 0

if self.enable_flashmla is True:
assert self.page_size == 64, "FlashMLA only support page_size=64"
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
self.page_size = 64
# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
Expand Down
Loading