
Commit b4ae8ea

qingquansong, hebiao064, Fridge003, and zcnrex authored and committed
Add Eagle Speculative Decoding to FA3 Backend (sgl-project#4951)
Co-authored-by: hebiao064 <[email protected]>
Co-authored-by: Baizhou Zhang <[email protected]>
Co-authored-by: zcnrex <[email protected]>
1 parent 327e039 commit b4ae8ea

File tree

2 files changed: +210 / -27 lines

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 197 additions & 27 deletions
@@ -45,6 +45,9 @@ def __init__(
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()
 
@@ -63,6 +66,10 @@ def __init__(
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.step_id = step_id
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -72,37 +79,125 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-        # Precompute page table
-        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-            forward_batch.req_pool_indices, : metadata.max_seq_len_k
-        ]
-
-        # Precompute strided indices
-        # [0, page_size, 2 * page_size, ...]
-        if self.page_size > 1:
-            self.strided_indices = torch.arange(
-                0, metadata.page_table.shape[1], self.page_size, device=self.device
-            )
-            metadata.page_table = (
-                metadata.page_table[:, self.strided_indices] // self.page_size
-            )
 
         if forward_batch.forward_mode == ForwardMode.DECODE:
-            # Precompute cumulative sequence lengths
+            if self.skip_prefill:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size * self.topk + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = (
+                    (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.page_table = metadata.page_table.repeat_interleave(
+                    self.topk, dim=0
+                )
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+                # Calculate page table indices and cache location indices to update the page table.
+                batch_indices = torch.arange(
+                    batch_size, device=device
+                ).repeat_interleave(self.topk * (self.step_id + 1))
+                topk_indices = torch.arange(self.topk, device=device).repeat(
+                    batch_size * (self.step_id + 1)
+                )
+                row_indices = batch_indices * self.topk + topk_indices
+
+                page_table_col_base_indices = seqlens_in_batch.unsqueeze(
+                    1
+                ) + torch.arange(self.step_id + 1, device=device)
+                page_table_col_indices = page_table_col_base_indices.view(-1).repeat(
+                    self.topk
+                )
+
+                cache_loc_col_indices = torch.arange(
+                    self.step_id + 1, device=device, dtype=torch.int32
+                ).repeat(batch_size * self.topk)
+
+                metadata.page_table[row_indices, page_table_col_indices] = cache_loc[
+                    row_indices, cache_loc_col_indices
+                ].to(torch.int32)
+            else:
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                # Precompute page table
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
+            draft_token_num = forward_batch.spec_info.draft_token_num
+
             metadata.cu_seqlens_q = torch.arange(
-                0, batch_size + 1, dtype=torch.int32, device=device
+                0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device
+            )
+
+            aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
+            metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(
+                forward_batch.spec_info.draft_token_num
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
             )
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ].repeat_interleave(draft_token_num, dim=0)
+            aug_cum_len = torch.nn.functional.pad(
+                torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+            )
+            for idx, single_seq_len in enumerate(aug_seq_lens):
+                metadata.page_table[
+                    idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len
+                ] *= forward_batch.spec_info.custom_mask[
+                    aug_cum_len[idx]
+                    * draft_token_num : aug_cum_len[idx + 1]
+                    * draft_token_num
+                ].view(
+                    draft_token_num, -1
+                )
+
+            metadata.max_seq_len_q = 1
         else:
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
             # Precompute cumulative sequence lengths
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
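The skip_prefill decode branch in the hunk above flattens the (batch, topk) draft branches into one attention row each and then patches the page table so the freshly drafted cache slots are visible. A self-contained sketch of the same tensor bookkeeping, using tiny made-up tensors in place of the real forward_batch fields:

    import torch

    # Made-up stand-ins for forward_batch fields (CPU tensors for illustration).
    seq_lens = torch.tensor([10, 7])   # per-request cached lengths
    topk, step_id = 2, 1               # 2 draft branches, second draft step
    batch_size = seq_lens.numel()

    # Every branch of a request shares the same KV length:
    # cached tokens + (step_id + 1) draft tokens produced so far.
    seq_lens_with_decode = seq_lens + (step_id + 1)
    cache_seqlens = seq_lens_with_decode.repeat_interleave(topk).to(torch.int32)
    # -> tensor([12, 12, 9, 9], dtype=torch.int32)

    # Cumulative KV lengths with a leading zero, as FlashAttention expects.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )

    # One single-token query row per branch.
    cu_seqlens_q = torch.arange(0, batch_size * topk + 1, dtype=torch.int32)

    # The per-request page table is duplicated once per branch; the real code then
    # scatters the draft tokens' cache locations into the last (step_id + 1) columns.
    page_table = torch.arange(batch_size * 12).view(batch_size, 12)  # dummy slots
    page_table = page_table.repeat_interleave(topk, dim=0)           # (batch * topk, 12)
    print(cache_seqlens, cu_seqlens_k.tolist(), page_table.shape)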
@@ -111,6 +206,16 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
         self.forward_metadata = metadata
 
     def forward_extend(
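Moving the strided-index conversion to the end of init_forward_metadata means every branch (decode, target-verify, extend) hands FlashAttention a page-level table when page_size > 1. A small sketch of that conversion with hypothetical values:

    import torch

    page_size = 4
    # Token-level page table for one request: token position -> cache slot.
    page_table = torch.arange(16).view(1, 16)

    # Keep one column per page: [0, page_size, 2 * page_size, ...],
    # then divide to turn slot indices into page indices.
    strided_indices = torch.arange(0, page_table.shape[1], page_size)
    page_level_table = page_table[:, strided_indices] // page_size
    print(page_level_table)   # tensor([[0, 1, 2, 3]])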
@@ -281,8 +386,6 @@ def forward_decode(
 
         # Pre-reshape query tensor
         q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
-        # Run attention with precomputed values
         o = flash_attn_with_kvcache(
             q=q_reshaped,
             k_cache=key_cache,
@@ -346,7 +449,11 @@ def init_cuda_graph_state(self, max_bs: int):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-        # Initialize fixed size tensors for decode operations
+        if self.speculative_num_steps > 0:
+            raise NotImplementedError(
+                "FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
+            )
+
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
@@ -385,7 +492,7 @@ def init_forward_metadata_capture_cuda_graph(
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
             req_pool_indices, :
         ]
-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_cuda_graph():
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
                 0, batch_size + 1, dtype=torch.int32, device=device
@@ -432,3 +539,66 @@ def init_forward_metadata_replay_cuda_graph(
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
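In the TARGET_VERIFY branch of init_forward_metadata above, every draft token becomes its own attention row, and the flattened tree mask (spec_info.custom_mask) is multiplied into that row's copy of the page table, zeroing the entries for draft slots that are not on the row's ancestor path. A self-contained sketch of that masking step, with a made-up mask standing in for the real custom_mask:

    import torch

    draft_token_num = 3
    seq_lens = torch.tensor([5])                                 # one request, 5 cached tokens
    aug_seq_lens = (seq_lens + draft_token_num).to(torch.int32)  # tensor([8])

    # One page-table row per draft token (cache slots 100..107 are made up).
    page_table = torch.arange(100, 108).repeat(draft_token_num, 1)

    # Hypothetical flattened tree mask for this request: each draft token sees the
    # whole prefix plus only its ancestors among the draft tokens (and itself).
    custom_mask = torch.tensor(
        [1, 1, 1, 1, 1, 1, 0, 0,    # draft token 0
         1, 1, 1, 1, 1, 1, 1, 0,    # draft token 1 (child of 0)
         1, 1, 1, 1, 1, 1, 0, 1],   # draft token 2 (sibling of 1, child of 0)
        dtype=torch.int64,
    )

    page_table[:, : aug_seq_lens[0]] *= custom_mask.view(draft_token_num, -1)
    print(page_table)   # masked-out draft slots are zeroed per row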

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 13 additions & 0 deletions
@@ -184,6 +184,19 @@ def init_attention_backend(self):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
+        elif self.server_args.attention_backend == "fa3":
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionMultiStepBackend,
+            )
+
+            self.draft_attn_backend = FlashAttentionMultiStepBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
