feat: use fa3 mla by default on hopper #5210

Merged · 19 commits · Apr 12, 2025

19 changes: 12 additions & 7 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -325,7 +325,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device

-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
                 metadata.cache_seqlens_int32 = (
@@ -527,7 +527,9 @@ def forward_extend(
             else (-1, -1)
         )
         k_descale, v_descale = None, None
-        if self.kv_cache_dtype_str != "auto":
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
             descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
             k_descale = layer.k_scale.expand(descale_shape)
             v_descale = layer.v_scale.expand(descale_shape)
@@ -670,10 +672,13 @@ def forward_decode(
         causal = not layer.is_cross_attention

         k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has corresponding quantization method so that layer.k_scale is not None
         if self.kv_cache_dtype_str != "auto":
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
             q = q.to(self.kv_cache_dtype)

         if not self.use_mla:
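
Reviewer note: a minimal, self-contained sketch of the guard added in the two hunks above, assuming a stand-in `_FakeLayer` object (not a class in this PR) with optional `k_scale`/`v_scale` tensors. It illustrates why the extra `layer.k_scale is not None` check matters: an fp8 KV cache can be enabled while the quantization method never attached per-layer scales, and without the guard the old code would call `.expand()` on `None`.

```python
import torch


class _FakeLayer:
    """Stand-in for RadixAttention with optional per-layer KV scales."""

    def __init__(self, k_scale=None, v_scale=None, tp_k_head_num=8):
        self.k_scale = k_scale
        self.v_scale = v_scale
        self.tp_k_head_num = tp_k_head_num


def compute_descale(kv_cache_dtype_str, layer, batch_size):
    # Mirrors the forward_extend variant: both conditions must hold.
    k_descale, v_descale = None, None
    if kv_cache_dtype_str != "auto" and layer.k_scale is not None:
        descale_shape = (batch_size, layer.tp_k_head_num)
        k_descale = layer.k_scale.expand(descale_shape)
        v_descale = layer.v_scale.expand(descale_shape)
    return k_descale, v_descale


# fp8 KV cache enabled but no scales attached: descaling is simply skipped.
assert compute_descale("fp8_e4m3", _FakeLayer(), batch_size=4) == (None, None)

# fp8 KV cache enabled and scales present: scalar scales broadcast per batch/head.
layer = _FakeLayer(k_scale=torch.tensor(0.5), v_scale=torch.tensor(0.25))
k_descale, _ = compute_descale("fp8_e4m3", layer, batch_size=4)
assert k_descale.shape == (4, 8)
```
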
@@ -834,7 +839,7 @@ def init_forward_metadata_capture_cuda_graph(
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
@@ -937,7 +942,7 @@ def init_forward_metadata_replay_cuda_graph(
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]

             if spec_info is not None:

25 changes: 21 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -80,6 +80,7 @@
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -245,7 +246,16 @@ def model_specific_adjustment(self):
                     "flashinfer" if is_flashinfer_available() else "triton"
                 )
             else:
-                server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3():
+                    if server_args.speculative_eagle_topk is None or (
+                        server_args.speculative_eagle_topk is not None
+                        and server_args.speculative_eagle_topk == 1
+                    ):
+                        server_args.attention_backend = "fa3"
+                    else:
+                        server_args.attention_backend = "triton"
+                else:
+                    server_args.attention_backend = "triton"
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
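
Reviewer note: the new default-selection branch, restated as a tiny standalone function for readability (illustrative only; `pick_default_attention_backend` is not part of this PR). One simplification worth noting: `topk is None or (topk is not None and topk == 1)` is equivalent to `topk is None or topk == 1`.

```python
from typing import Optional


def pick_default_attention_backend(
    hopper_with_cuda_12_3: bool, speculative_eagle_topk: Optional[int]
) -> str:
    """FA3 becomes the default only on Hopper with CUDA >= 12.3 and only when
    EAGLE speculative top-k is unset or 1; everything else stays on triton."""
    if hopper_with_cuda_12_3 and speculative_eagle_topk in (None, 1):
        return "fa3"
    return "triton"


assert pick_default_attention_backend(True, None) == "fa3"
assert pick_default_attention_backend(True, 1) == "fa3"
assert pick_default_attention_backend(True, 4) == "triton"      # topk > 1 keeps triton
assert pick_default_attention_backend(False, None) == "triton"  # non-Hopper or older CUDA
```
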
@@ -263,6 +273,16 @@ def model_specific_adjustment(self):
             else:
                 raise ValueError(f"MLA optimization not supported on CPU.")

+        if (
+            server_args.attention_backend == "fa3"
+            and server_args.kv_cache_dtype == "fp8_e5m2"
+        ):
+            logger.warning(
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -889,9 +909,6 @@ def init_attention_backend(self):
                 "FlashAttention v3 Backend requires SM>=90. "
                 "Please use `--attention-backend flashinfer`."
             )
-            logger.warning(
-                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
-            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
             )

9 changes: 9 additions & 0 deletions python/sglang/srt/utils.py
@@ -1828,3 +1828,12 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def is_hopper_with_cuda_12_3():
+    if not is_cuda():
+        return False
+    is_hopper = torch.cuda.get_device_capability()[0] == 9
+    cuda_version = torch.version.cuda.split(".")
+    is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
+    return is_hopper and is_cuda_compatible
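
Reviewer note: a standalone mirror of the version parse in the new helper, for illustration only (the real helper additionally requires a CUDA build and an SM 9.x GPU). The comparison is pinned to major version 12, so for example a hypothetical CUDA 13.0 would also return False.

```python
def _cuda_version_ok(cuda_version_str: str) -> bool:
    # Same parse as is_hopper_with_cuda_12_3: split "major.minor[...]" on dots.
    major, minor = cuda_version_str.split(".")[:2]
    return int(major) == 12 and int(minor) >= 3


assert _cuda_version_ok("12.3")
assert _cuda_version_ok("12.4")
assert not _cuda_version_ok("12.1")
assert not _cuda_version_ok("11.8")
assert not _cuda_version_ok("13.0")  # pinned to major == 12
```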