@@ -411,7 +411,13 @@ def init_cuda_graph_state(self, max_bs: int):
         to avoid memory allocations.
         """
         self.decode_cuda_graph_metadata = {
-            # Page table for token mapping (batch_size, max_context_len)
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
             "page_table": torch.zeros(
                 max_bs,
                 (self.max_context_len + self.page_size - 1) // self.page_size,
@@ -427,13 +433,6 @@ def init_cuda_graph_state(self, max_bs: int):
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
-            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
-            "cu_seqlens_q": torch.arange(
-                0, max_bs + 128, dtype=torch.int32, device=self.device
-            ),
-            "cu_seqlens_k": torch.zeros(
-                max_bs + 128, dtype=torch.int32, device=self.device
-            ),
         }

         self.target_verify_metadata = {
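The sizing change above (from `max_bs + 128` down to `max_bs + 1`) works because decode issues exactly one query token per request, so the cumulative query offsets for a batch of size `bs` are exactly `0, 1, ..., bs` and can be served by a slice of one preallocated `arange`. A minimal sketch of that invariant (plain torch; the CPU device and sizes are illustrative, not the backend's real configuration):

```python
import torch

max_bs = 8  # illustrative capture-time maximum batch size

# One query token per request in decode mode, so cumulative query
# lengths are 0, 1, ..., bs for every batch size up to max_bs;
# a single preallocated arange of max_bs + 1 entries covers them all.
cu_seqlens_q = torch.arange(0, max_bs + 1, dtype=torch.int32)

bs = 3
assert cu_seqlens_q[: bs + 1].tolist() == [0, 1, 2, 3]
```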
@@ -471,26 +470,21 @@ def init_forward_metadata_capture_cuda_graph(
         if forward_mode.is_decode():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cu_seqlens_q = torch.arange(
-                    0, bs + 1, dtype=torch.int32, device=device
-                )
                 metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
                     "cache_seqlens"
                 ][:bs]
-
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
                 metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
                     : bs + 1
                 ]
-
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
                         metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
                     ),
                     (1, 0),
                 )
-                metadata.max_seq_len_k = seq_lens.max().item() + (
-                    self.speculative_step_id + 1
-                )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
                 ][req_pool_indices, :]
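The capture-side change replaces the per-capture `torch.arange` allocation with a slice of the preallocated buffer. Slicing matters here: a slice is a view sharing storage with its parent, so the pointer baked into the captured graph stays valid and replay can refresh the values in place. A small sketch of the pattern (the buffer names mirror the dict above, but this is illustrative, not the backend's real metadata class):

```python
import torch

max_bs, bs = 8, 3
buffers = {
    "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32),
    "cu_seqlens_q": torch.arange(0, max_bs + 1, dtype=torch.int32),
}

# Capture-time: take views into the preallocated buffers instead of
# allocating fresh tensors, so the captured kernels read stable storage.
cache_seqlens = buffers["cache_seqlens"][:bs]
cu_seqlens_q = buffers["cu_seqlens_q"][: bs + 1]

# The views alias the parent buffers: in-place writes to a buffer
# are visible through the sliced tensors the graph captured.
assert cache_seqlens.data_ptr() == buffers["cache_seqlens"].data_ptr()
```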
@@ -560,26 +554,21 @@ def init_forward_metadata_replay_cuda_graph(
         out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        device = seq_lens.device
         seq_lens = seq_lens[:bs]
-        req_pool_indices = req_pool_indices[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
         if forward_mode.is_decode():
             metadata = self.decode_cuda_graph_metadata[bs]

             if spec_info is not None:
                 # Draft Decode
-                max_len = seq_lens_cpu.max().item()
-                metadata.max_seq_len_k = max_len + (self.speculative_step_id + 1)
-
                 metadata.cache_seqlens_int32.copy_(
                     (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
                 )

                 metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
                     self.speculative_step_id + 1
                 )
-
                 metadata.cu_seqlens_k.copy_(
                     torch.nn.functional.pad(
                         torch.cumsum(
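On the replay side the rule is the inverse: the graph already holds references to the tensors it captured, so new values must be written with `copy_()` rather than by rebinding the metadata attribute to a fresh tensor, which the replayed kernels would never see. Host-side scalars such as `max_seq_len_k` are the exception and are read from `seq_lens_cpu` outside the graph. A minimal sketch of the in-place update, assuming a captured `cu_seqlens_k` for a batch of 3:

```python
import torch

# Stand-in for metadata.cu_seqlens_k as captured by the graph.
captured_cu_seqlens_k = torch.zeros(4, dtype=torch.int32)

# Replay-time: write the new cumulative key lengths in place; rebinding
# the attribute would leave the graph reading the old storage.
new_cache_seqlens = torch.tensor([5, 7, 9], dtype=torch.int32)
captured_cu_seqlens_k.copy_(
    torch.nn.functional.pad(
        torch.cumsum(new_cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
)
assert captured_cu_seqlens_k.tolist() == [0, 5, 12, 21]
```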