
Commit e300570

Optimized CUDA graph draft decode by removing duplicate code
Co-authored-by: Qingquan Song <[email protected]>
1 parent 21bf3fe commit e300570

File tree

1 file changed: +11 −22 lines

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 11 additions & 22 deletions
@@ -411,7 +411,13 @@ def init_cuda_graph_state(self, max_bs: int):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
-            # Page table for token mapping (batch_size, max_context_len)
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -427,13 +433,6 @@ def init_cuda_graph_state(self, max_bs: int):
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.arange(
-                0, max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
         }

         self.target_verify_metadata = {
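
Note: besides consolidating the duplicate dict entries, the new allocation right-sizes the buffers. A batch of up to max_bs sequences has only max_bs + 1 cumulative-offset boundaries, so the earlier max_bs + 128 sizing was oversized. These tensors are allocated once, up front, because CUDA graph replay reuses the exact memory addresses recorded at capture time. A minimal sketch of that pattern (not SGLang's actual capture loop; all names here are illustrative):

```python
import torch

def demo_static_buffers(max_bs: int = 8) -> torch.Tensor:
    device = "cuda"
    # Preallocate static buffers once, as init_cuda_graph_state does.
    static_seqlens = torch.zeros(max_bs, dtype=torch.int32, device=device)
    static_out = torch.zeros(max_bs, dtype=torch.int32, device=device)

    # Warm up on a side stream before capture (required by PyTorch).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_seqlens + 1)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: the graph records the addresses of the static buffers.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_seqlens + 1)

    # Replay: write new inputs in place, then rerun the recorded kernels.
    static_seqlens.copy_(torch.arange(max_bs, dtype=torch.int32, device=device))
    g.replay()
    return static_out  # now holds arange(max_bs) + 1
```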
@@ -471,26 +470,21 @@ def init_forward_metadata_capture_cuda_graph(
         if forward_mode.is_decode():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
-
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (
-                    self.speculative_step_id + 1
-                )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
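
The cu_seqlens_k construction kept here (a running sum padded with a leading zero) turns per-sequence KV lengths into the cumulative offsets that FlashAttention's varlen interface expects. A standalone check of that arithmetic, runnable on CPU with made-up lengths:

```python
import torch
import torch.nn.functional as F

# Per-sequence KV-cache lengths for a batch of three requests (example values).
cache_seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Cumulative sum with a zero prepended: sequence i occupies
# [cu_seqlens_k[i], cu_seqlens_k[i + 1]) in the packed KV layout.
cu_seqlens_k = F.pad(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_k)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```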
@@ -560,26 +554,21 @@ def init_forward_metadata_replay_cuda_graph(
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
         if forward_mode.is_decode():
             metadata = self.decode_cuda_graph_metadata[bs]

             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.speculative_step_id + 1)
-
                 metadata.cache_seqlens_int32.copy_(
                     (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )

                 metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                     self.speculative_step_id + 1
                 )
-
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
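
This replay path mutates the captured buffers with .copy_() rather than rebinding the attributes: kernels recorded in the graph keep reading the storage they saw at capture time, so assigning a fresh tensor would leave them pointing at stale memory. A small illustration of the distinction (plain tensors, no graph needed):

```python
import torch

buf = torch.zeros(4, dtype=torch.int32)
alias = buf  # stands in for the pointer a captured kernel holds

# Rebinding creates new storage; the alias (the "captured" view) is unchanged.
buf = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
print(alias)  # tensor([0, 0, 0, 0], dtype=torch.int32)

# In-place copy_ updates the original storage the alias still points at.
alias.copy_(torch.tensor([1, 2, 3, 4], dtype=torch.int32))
print(alias)  # tensor([1, 2, 3, 4], dtype=torch.int32)
```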
