
Commit ff9b8bf

feat: use fa3 mla by default on hopper (sgl-project#5210)

Authored and committed by zhyncs, yundai424, and hebiao064.

Co-authored-by: yundai424 <[email protected]>
Co-authored-by: hebiao064 <[email protected]>

1 parent 31ac7d2, commit ff9b8bf

File tree

3 files changed: +42 -11 lines

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 12 additions & 7 deletions
@@ -325,7 +325,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
 
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
                 metadata.cache_seqlens_int32 = (
@@ -527,7 +527,9 @@ def forward_extend(
             else (-1, -1)
         )
         k_descale, v_descale = None, None
-        if self.kv_cache_dtype_str != "auto":
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has a corresponding quantization method so that layer.k_scale is not None
+        if self.kv_cache_dtype_str != "auto" and layer.k_scale is not None:
             descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
             k_descale = layer.k_scale.expand(descale_shape)
             v_descale = layer.v_scale.expand(descale_shape)
@@ -670,10 +672,13 @@ def forward_decode(
         causal = not layer.is_cross_attention
 
         k_descale, v_descale = None, None
+        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
+        # has a corresponding quantization method so that layer.k_scale is not None
         if self.kv_cache_dtype_str != "auto":
-            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
-            k_descale = layer.k_scale.expand(descale_shape)
-            v_descale = layer.v_scale.expand(descale_shape)
+            if layer.k_scale is not None:
+                descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+                k_descale = layer.k_scale.expand(descale_shape)
+                v_descale = layer.v_scale.expand(descale_shape)
         q = q.to(self.kv_cache_dtype)
 
         if not self.use_mla:
@@ -834,7 +839,7 @@ def init_forward_metadata_capture_cuda_graph(
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
         device = seq_lens.device
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
@@ -937,7 +942,7 @@ def init_forward_metadata_replay_cuda_graph(
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None:
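For illustration, the guarded descale pattern that forward_extend and forward_decode now share can be shown in isolation. This is a minimal sketch only: _DummyLayer and compute_kv_descale are hypothetical stand-ins, not names from the sglang codebase, and in the real backend the k_scale/v_scale tensors are populated by RadixAttention's quantization method.

# Hypothetical, self-contained sketch of the "guard descale on layer.k_scale" pattern.
# _DummyLayer and compute_kv_descale are illustrative names, not sglang APIs.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class _DummyLayer:
    tp_k_head_num: int
    k_scale: Optional[torch.Tensor] = None
    v_scale: Optional[torch.Tensor] = None


def compute_kv_descale(
    kv_cache_dtype_str: str, layer: _DummyLayer, batch_size: int
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    k_descale, v_descale = None, None
    # Build descale tensors only when an fp8 KV cache is explicitly enabled AND the
    # layer actually carries a k_scale; otherwise return (None, None) instead of
    # dereferencing a missing scale.
    if kv_cache_dtype_str != "auto" and layer.k_scale is not None:
        descale_shape = (batch_size, layer.tp_k_head_num)
        k_descale = layer.k_scale.expand(descale_shape)
        v_descale = layer.v_scale.expand(descale_shape)
    return k_descale, v_descale


# A layer without scales silently skips descaling instead of raising.
print(compute_kv_descale("fp8_e4m3", _DummyLayer(tp_k_head_num=8), batch_size=4))
# A layer with scalar scales expands them to (batch_size, tp_k_head_num).
quantized = _DummyLayer(tp_k_head_num=8, k_scale=torch.tensor(0.5), v_scale=torch.tensor(0.5))
print(compute_kv_descale("fp8_e4m3", quantized, batch_size=4)[0].shape)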

python/sglang/srt/model_executor/model_runner.py

Lines changed: 21 additions & 4 deletions
@@ -80,6 +80,7 @@
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_hopper_with_cuda_12_3,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -245,7 +246,16 @@ def model_specific_adjustment(self):
                     "flashinfer" if is_flashinfer_available() else "triton"
                 )
             else:
-                server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3():
+                    if server_args.speculative_eagle_topk is None or (
+                        server_args.speculative_eagle_topk is not None
+                        and server_args.speculative_eagle_topk == 1
+                    ):
+                        server_args.attention_backend = "fa3"
+                    else:
+                        server_args.attention_backend = "triton"
+                else:
+                    server_args.attention_backend = "triton"
             logger.info(
                 f"Attention backend not set. Use {server_args.attention_backend} backend by default."
             )
@@ -263,6 +273,16 @@ def model_specific_adjustment(self):
             else:
                 raise ValueError(f"MLA optimization not supported on CPU.")
 
+        if (
+            server_args.attention_backend == "fa3"
+            and server_args.kv_cache_dtype == "fp8_e5m2"
+        ):
+            logger.warning(
+                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
+                "Setting attention backend to triton."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -889,9 +909,6 @@ def init_attention_backend(self):
                 "FlashAttention v3 Backend requires SM>=90. "
                 "Please use `--attention-backend flashinfer`."
             )
-            logger.warning(
-                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
-            )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
             )
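The default-selection rule added in the @@ -245 hunk can be condensed into a small standalone function. This is a hedged sketch, not sglang code: pick_default_attention_backend is an illustrative name, and the real logic also sits behind the flashinfer/MLA branches shown in the diff.

# Illustrative summary of the new default attention-backend choice.
# pick_default_attention_backend is a hypothetical helper, not an sglang API.
from typing import Optional


def pick_default_attention_backend(
    hopper_with_cuda_12_3: bool, speculative_eagle_topk: Optional[int]
) -> str:
    # fa3 becomes the default only on Hopper (SM90) with CUDA >= 12.3, and only
    # when EAGLE speculative decoding is off (topk is None) or uses topk == 1.
    if hopper_with_cuda_12_3 and speculative_eagle_topk in (None, 1):
        return "fa3"
    return "triton"


assert pick_default_attention_backend(True, None) == "fa3"
assert pick_default_attention_backend(True, 1) == "fa3"
assert pick_default_attention_backend(True, 4) == "triton"
assert pick_default_attention_backend(False, None) == "triton"

Separately, the @@ -263 hunk above falls back to triton whenever fa3 is paired with an fp8_e5m2 KV cache, since FlashAttention3 only supports fp8_e4m3 for FP8.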

python/sglang/srt/utils.py

Lines changed: 9 additions & 0 deletions
@@ -1828,3 +1828,12 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def is_hopper_with_cuda_12_3():
+    if not is_cuda():
+        return False
+    is_hopper = torch.cuda.get_device_capability()[0] == 9
+    cuda_version = torch.version.cuda.split(".")
+    is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
+    return is_hopper and is_cuda_compatible
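To check whether a given machine would pick up the new default, the detection logic can be reproduced outside sglang. This sketch mirrors is_hopper_with_cuda_12_3 under the assumption that torch.cuda.is_available() is an acceptable stand-in for sglang's is_cuda helper; the hopper_with_cuda_12_3 name and the __main__ block are illustrative only.

# Standalone sketch of the Hopper + CUDA >= 12.3 check added above.
import torch


def hopper_with_cuda_12_3() -> bool:
    # torch.version.cuda is None on CPU-only builds, so guard both conditions.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    major, _minor = torch.cuda.get_device_capability()
    cuda_major, cuda_minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    # Hopper GPUs report compute capability 9.x; fa3 additionally needs CUDA 12.3+.
    return major == 9 and cuda_major == 12 and cuda_minor >= 3


if __name__ == "__main__":
    print("fa3-eligible:", hopper_with_cuda_12_3())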
