[Feature] Support FA3 backend for MLA #4831

Merged
merged 4 commits into from Mar 29, 2025
244 changes: 171 additions & 73 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -13,7 +13,9 @@

import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
@@ -58,6 +60,9 @@ def __init__(
self.decode_cuda_graph_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.page_size = model_runner.page_size
self.use_mla = (
model_runner.model_config.attention_arch == AttentionArch.MLA
) and (not global_server_args_dict["disable_mla"])

def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata to cache repetitive calculations."""
@@ -117,23 +122,30 @@ def forward_extend(
forward_batch: ForwardBatch,
save_kv_cache=True,
):
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)

if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)

# Use precomputed metadata
metadata = self.forward_metadata

# Use Flash Attention for prefill
# Calculate window size (can be moved to metadata if layer properties don't change)
# we don't subtract 1 from layer.sliding_window_size here because model.get_attention_sliding_window_size() already does
# the window is inclusive on both sides
@@ -142,36 +154,72 @@
if layer.sliding_window_size is not None
else (-1, -1)
)
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
key_cache, value_cache = kv_cache[0], kv_cache[1]

key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)

page_table = metadata.page_table

o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
)
# Use Flash Attention for prefill
if not self.use_mla:
# Do multi-head attention
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
key_cache, value_cache = kv_cache[0], kv_cache[1]
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
)
else:
# Do absorbed multi-head latent attention (MLA)
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
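The MLA branch above stores the latent KV and the decoupled rope key packed in a single cache buffer and splits both the cache and the query along the head dimension. Below is a minimal sketch of that layout, assuming DeepSeek-V2-style sizes (kv_lora_rank = 512, qk_rope_head_dim = 64, so head_dim = 576 and v_head_dim = 512); names and sizes are illustrative, not taken from this PR:

import torch

kv_lora_rank, rope_dim = 512, 64
head_dim = kv_lora_rank + rope_dim   # 576: latent part and rope part packed together
num_tokens, num_heads = 8, 16

# One cache entry per token: [c_kv | k_rope] concatenated along the last dim,
# with a single latent KV head (tp_k_head_num == tp_v_head_num == 1 here).
kv_cache = torch.randn(num_tokens, 1, head_dim)
c_kv = kv_cache[:, :, :kv_lora_rank]    # latent KV; also serves as V (v_cache=)
k_rope = kv_cache[:, :, kv_lora_rank:]  # decoupled rope key (k_cache=)

# Queries are split the same way before the kernel call.
q_all = torch.randn(num_tokens, num_heads, head_dim)
q_nope = q_all[:, :, :kv_lora_rank]     # scored against c_kv (passed as qv=)
q_rope = q_all[:, :, kv_lora_rank:]     # scored against k_rope (passed as q=)

assert q_nope.shape[-1] == c_kv.shape[-1]
assert q_rope.shape[-1] == k_rope.shape[-1]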

def forward_decode(
self,
@@ -184,24 +232,29 @@ def forward_decode(
) -> torch.Tensor:
"""Forward pass with FlashAttention using precomputed metadata."""
# Save KV cache if needed
if k is not None and v is not None and save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
if k is not None:
assert v is not None
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)

# Get KV cache
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
key_cache, value_cache = kv_cache[0], kv_cache[1]
# Use precomputed metadata
metadata = self.forward_metadata

# Pre-reshape query tensor
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
# Calculate window size (can be moved to metadata if layer properties don't change)
# we don't subtract 1 from layer.sliding_window_size here because model.get_attention_sliding_window_size() already does
# the window is inclusive on both sides
@@ -210,33 +263,79 @@
if layer.sliding_window_size is not None
else (-1, -1)
)
# Run attention with precomputed values
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)

page_table = metadata.page_table

o = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
if not self.use_mla:
# Do multi-head attention

# Get KV cache
kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
key_cache, value_cache = kv_cache[0], kv_cache[1]
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)

# Pre-reshape query tensor
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

# Run attention with precomputed values
o = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
)
else:
# Do absorbed multi-head latent attention (MLA)
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_rope = kv_cache[:, :, layer.v_head_dim :]
c_kv = kv_cache[:, :, : layer.v_head_dim]
k_rope_cache = k_rope.view(
-1,
self.page_size,
layer.tp_k_head_num,
layer.head_dim - layer.v_head_dim,
)
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)

q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]

o = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
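For intuition about what the MLA call computes, here is a tiny pure-PyTorch reference for a single decode token (one latent KV head, so no causal masking is needed with one query position): the attention score combines q_rope·k_rope with q_nope·c_kv, and c_kv itself is the value. This is a sketch of the math, not code from this PR, and the scale below is illustrative (layer.scaling may include model-specific factors):

import torch

def mla_absorbed_decode_ref(q_nope, q_rope, c_kv, k_rope, scale):
    # q_nope: [heads, lora], q_rope: [heads, rope]
    # c_kv:   [seq, lora],   k_rope: [seq, rope]
    scores = (q_rope @ k_rope.T + q_nope @ c_kv.T) * scale  # [heads, seq]
    probs = scores.softmax(dim=-1)
    return probs @ c_kv                                     # [heads, lora]

heads, lora, rope, seq = 16, 512, 64, 32
out = mla_absorbed_decode_ref(
    torch.randn(heads, lora), torch.randn(heads, rope),
    torch.randn(seq, lora), torch.randn(seq, rope),
    scale=(lora + rope) ** -0.5,
)
print(out.shape)  # torch.Size([16, 512]); flattens to tp_q_head_num * v_head_dim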

def init_cuda_graph_state(self, max_bs: int):
"""Initialize CUDA graph state for the attention backend.
@@ -286,7 +385,6 @@ def init_forward_metadata_capture_cuda_graph(
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
req_pool_indices, :
]

if forward_mode == ForwardMode.DECODE:
# Precompute cumulative sequence lengths
metadata.cu_seqlens_q = torch.arange(
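The key_cache.view(-1, self.page_size, ...) reshapes in this file feed flash_attn_with_kvcache's paged lookup: page_table maps each request's logical pages to physical pages of the KV pool. A small sketch of that indexing, with made-up sizes:

import torch

page_size, num_pages, heads, head_dim = 4, 8, 2, 64
kv = torch.randn(num_pages * page_size, heads, head_dim)
kv_paged = kv.view(num_pages, page_size, heads, head_dim)  # same reshape as above

# page_table[b, i] is the physical page backing logical page i of request b.
page_table = torch.tensor([[3, 0, 5]])  # one request spanning three logical pages
token_pos = 6                           # the request's 7th token
page, offset = divmod(token_pos, page_size)
entry = kv_paged[page_table[0, page], offset]  # -> [heads, head_dim]
print(entry.shape)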
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -230,6 +230,10 @@ def model_specific_adjustment(self):
elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend == "fa3":
logger.info(
f"MLA optimization is turned on. Use flash attention 3 backend."
)
else:
logger.info("MLA optimization is turned on. Use triton backend.")
server_args.attention_backend = "triton"
@@ -879,7 +883,7 @@ def init_attention_backend(self):
"Please use `--attention-backend flashinfer`."
)
logger.warning(
"FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
"FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
)
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
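With this change, an MLA model can be served on the FA3 backend by selecting it explicitly; a typical launch might look like the line below (flag names as of this PR, model path illustrative):

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite --attention-backend fa3 --trust-remote-code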
4 changes: 4 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -655,6 +655,7 @@ def __init__(
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -667,6 +668,9 @@ def no_absorb(self, forward_batch: ForwardBatch) -> bool:
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
)
elif self.attention_backend == "fa3":
# Flash Attention: Keep absorbing for all extend/decode
return False
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (
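In short, with FA3 the model keeps the absorbed (latent) path for all extend and decode, while the other backends fall back to normal multi-head attention for pure prefill. A condensed sketch of the dispatch (illustrative only; it folds the flashinfer ragged-prefill conditions into a single flag):

def uses_weight_absorption(backend: str, is_pure_prefill: bool) -> bool:
    # Inverse of no_absorb above, simplified for illustration.
    if backend == "fa3":
        return True              # FA3: absorb for all extend and decode
    return not is_pure_prefill   # flashinfer/triton: normal MHA for pure prefill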