Commit 6c45067

Add Speculative Decoding Eagle3 topk > 1 (sgl-project#5318)

qingquansong authored and committed
Co-authored-by: Stefan He <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>

1 parent: d4772c0
commit: 6c45067

File tree

6 files changed, +861 -156 lines changed

python/sglang/srt/layers/attention/flashattention_backend.py
Lines changed: 781 additions & 150 deletions
(Large diffs are not rendered by default.)

python/sglang/srt/model_executor/model_runner.py
Lines changed: 11 additions & 4 deletions

@@ -221,7 +221,16 @@ def model_specific_adjustment(self):
         server_args = self.server_args
 
         if server_args.attention_backend is None:
-            # By default, use flashinfer for non-mla attention and triton for mla attention
+            """
+            We auto-select the fastest attention backend according to the current offering:
+            1. Models with MHA architecture (e.g., Llama, Qwen)
+               1.1 We turn on FA3 on Hopper unless the user runs spec decode with topk > 1 or page_size > 1.
+               1.2 In other cases, we use flashinfer if available, otherwise triton.
+            2. Models with MLA architecture
+               2.1 We use the FA3 backend on Hopper.
+               2.2 Otherwise, we use the triton backend.
+            """
+
             if not self.use_mla_backend:
                 if (
                     is_hopper_with_cuda_12_3()
@@ -234,9 +243,7 @@ def model_specific_adjustment(self):
                     "flashinfer" if is_flashinfer_available() else "triton"
                 )
             else:
-                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
-                    server_args
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
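
Note on the selection logic: the hunk elides the middle of the MHA branch, so the full rule is easiest to read as one small decision function. The sketch below is illustrative only; plain booleans stand in for runtime helpers such as is_hopper_with_cuda_12_3() and is_flashinfer_available(), and the MHA branch condition is inferred from the new docstring rather than copied from the source.

# Illustrative sketch, not the actual sglang code: the MHA condition is
# inferred from the docstring above (the hunk elides that branch), and
# plain booleans replace the real runtime checks.
def pick_attention_backend(
    use_mla_backend: bool,
    hopper_with_cuda_12_3: bool,
    flashinfer_available: bool,
    spec_off_or_topk_one: bool,
    page_size_one: bool,
) -> str:
    if not use_mla_backend:
        # MHA models (e.g., Llama, Qwen): FA3 on Hopper, unless spec decode
        # with topk > 1 or page_size > 1 rules it out.
        if hopper_with_cuda_12_3 and spec_off_or_topk_one and page_size_one:
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA models: after this commit, Hopper always gets FA3, even with topk > 1.
    return "fa3" if hopper_with_cuda_12_3 else "triton"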

python/sglang/srt/server_args.py
Lines changed: 12 additions & 1 deletion

@@ -359,7 +359,18 @@ def __post_init__(self):
 
         if self.page_size > 1 and self.speculative_eagle_topk > 1:
             self.speculative_eagle_topk = 1
-            logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+            logger.info(
+                "speculative_eagle_topk is adjusted to 1 when page_size > 1"
+            )
+
+        if (
+            self.speculative_eagle_topk == 1
+            and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
+        ):
+            logger.info(
+                "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
+            )
+            self.speculative_num_draft_tokens = self.speculative_num_steps + 1
 
         # The token generated from the verify step is counted.
         # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
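
Why num_steps + 1: with topk == 1 the EAGLE draft degenerates to a single chain, so each of the speculative_num_steps draft steps proposes exactly one token and the verify step contributes one more. A worked example of the adjustment, with hypothetical values:

# Worked example of the new __post_init__ adjustment (hypothetical values).
speculative_num_steps = 5
speculative_eagle_topk = 1
speculative_num_draft_tokens = 8  # user-supplied, inconsistent for topk == 1

if (
    speculative_eagle_topk == 1
    and speculative_num_draft_tokens != speculative_num_steps + 1
):
    # Chain draft: num_steps proposed tokens plus 1 verified bonus token.
    speculative_num_draft_tokens = speculative_num_steps + 1

assert speculative_num_draft_tokens == 6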

python/sglang/srt/utils.py
Lines changed: 2 additions & 0 deletions

@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
     return server_args.page_size == 1
 
 
+# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
+# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
 def is_no_spec_infer_or_topk_one(server_args):
     return server_args.speculative_eagle_topk is None or (
         server_args.speculative_eagle_topk is not None
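
The hunk truncates mid-function, so the predicate's full body is not shown. Judging from its name and the visible opening lines, it plausibly reduces to the sketch below; this is a guess for orientation, not the verbatim sglang source:

# Plausible shape of the truncated predicate (an assumption, not verbatim
# source): True when speculative inference is off, or when it runs with
# topk == 1, i.e. exactly the cases where FA3 was previously allowed.
def is_no_spec_infer_or_topk_one(server_args):
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk is not None
        and server_args.speculative_eagle_topk == 1
    )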

test/srt/run_suite.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ class TestFile:
     TestFile("test_chunked_prefill.py", 336),
     TestFile("test_eagle_infer.py", 500),
     TestFile("test_ebnf_constrained.py"),
-    TestFile("test_fa3.py", 5),
+    TestFile("test_fa3.py", 200),
     TestFile("test_fp8_kernel.py", 8),
     TestFile("test_embedding_openai_server.py", 36),
     TestFile("test_hidden_states.py", 55),

test/srt/test_fa3.py
Lines changed: 54 additions & 0 deletions

@@ -173,6 +173,60 @@ def test_gsm8k(self):
         self.assertGreater(avg_spec_accept_length, 1.5)
 
 
+class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
+    """Test FlashAttention3 with speculative decode enabled and topk > 1."""
+
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    @classmethod
+    def get_server_args(cls):
+        args = super().get_server_args()
+        args.extend(
+            [
+                "--cuda-graph-max-bs",
+                "2",
+                "--speculative-algorithm",
+                "EAGLE3",
+                "--speculative-draft",
+                "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
+                "--speculative-num-steps",
+                "5",
+                "--speculative-eagle-topk",
+                "4",
+                "--speculative-num-draft-tokens",
+                "8",
+                "--dtype",
+                "float16",
+            ]
+        )
+        return args
+
+    def test_gsm8k(self):
+        """
+        Override test_gsm8k to additionally check the average speculative accept length.
+        """
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=DATA_PATH,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+        print(f"{avg_spec_accept_length=}")
+        self.assertGreater(avg_spec_accept_length, 1.8)
+
+
 class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
     """Test FlashAttention3 with speculative decode enabled."""
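
The new test drives two server endpoints that are also handy for manual verification. A minimal sketch, assuming a locally running sglang server (the base_url value is an assumption for illustration):

import requests

base_url = "http://127.0.0.1:30000"  # assumed local server address
requests.get(base_url + "/flush_cache")  # reset cache between runs, as the test does
info = requests.get(base_url + "/get_server_info").json()
print(info["avg_spec_accept_length"])  # the test asserts this exceeds 1.8

Roughly, an average accept length of 1.8 means each verify step keeps about 0.8 drafted tokens on top of the one token the verify step always produces.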
