@@ -478,3 +478,185 @@ def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
             layer.logit_cap,
         )
         return o
+
+
+class DoubleSparseAttnBackend(AttentionBackend):
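+    # Backend for double sparse attention: in addition to the full K/V cache it stores a
+    # small "label" slice of K (the channels listed in sorted_channels), which lets decode
+    # approximate attention scores and attend to only the heaviest tokens of long sequences.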
+    def __init__(self, model_runner: ModelRunner):
+        # Lazy import to avoid the initialization of cuda context
+        from sglang.srt.layers.triton_attention.decode_attention import (
+            decode_attention_fwd,
+        )
+        from sglang.srt.layers.triton_attention.extend_attention import (
+            extend_attention_fwd,
+        )
+        from sglang.srt.layers.triton_attention.sparse_decode_attention import (
+            decode_sparse_attention_fwd,
+        )
+
+        super().__init__()
+
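+        # Keep both the dense Triton decode kernel and the sparse decode kernel;
+        # forward_decode chooses between them per batch based on the max sequence length.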
+        self.decode_attention_fwd = decode_attention_fwd
+        self.decode_sparse_attention_fwd = decode_sparse_attention_fwd
+        self.extend_attention_fwd = extend_attention_fwd
+        self.num_head = model_runner.model_config.num_attention_heads
+
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+            self.reduce_dtype = torch.float32
+        else:
+            self.reduce_dtype = torch.float16
+
+        self.forward_metadata = None
+
+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+    def init_forward_metadata(
+        self, batch: ScheduleBatch, input_metadata: InputMetadata
+    ):
+        """Init auxiliary variables for triton attention backend."""
+
+        if input_metadata.forward_mode.is_decode():
+            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
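+            # start_loc[i] is the offset of request i in the flattened token dimension
+            # (exclusive prefix sum of seq_lens), as expected by the dense decode kernel.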
+
+            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(input_metadata.seq_lens).item()
+            max_extend_len = None
+            # NOTE: Align sequence order with req_to_token order
+            ds_req_to_token = input_metadata.req_to_token_pool.req_to_token[
+                input_metadata.req_pool_indices
+            ]
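+            # ds_req_to_token: req_to_token rows reordered to match the batch order, passed
+            # to the sparse decode kernel in place of the full pool table.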
+        else:
+            start_loc = attn_logits = max_seq_len = None
+            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
+            ds_req_to_token = None
+
+        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len, ds_req_to_token
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO(Andy): Support CUDA graph for double sparse attention
+        raise ValueError(
+            "Double sparse attention does not support CUDA graph for now. "
+            "Please use --disable-cuda-graph."
+        )
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
+    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        k_label = torch.gather(
+            k,
+            2,
+            input_metadata.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
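+        # Store K/V plus the gathered label channels of K in the KV cache; decode reads
+        # the labels back via get_label_buffer to score tokens cheaply.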
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, input_metadata.out_cache_loc, k, v, k_label
+        )
+
+        start_loc, attn_logits, max_seq_len, max_extend_len, ds_req_to_token = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            input_metadata.req_to_token_pool.req_to_token,
+            input_metadata.req_pool_indices,
+            input_metadata.seq_lens,
+            input_metadata.extend_seq_lens,
+            input_metadata.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
+    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        start_loc, attn_logits, max_seq_len, max_extend_len, ds_req_to_token = self.forward_metadata
+
+        k_label = torch.gather(
+            k,
+            2,
+            input_metadata.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        input_metadata.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, input_metadata.out_cache_loc, k, v, k_label
+        )
+
+        # NOTE(Andy): sparse decode shouldn't be used when max_len_in_batch < heavy_token_num,
+        # and a minimum sequence length is required before switching to sparse decode.
+        if (
+            max_seq_len < input_metadata.heavy_token_num
+            or max_seq_len < input_metadata.sparse_decode_thresold
+        ):
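+            # Short sequences: fall back to the dense Triton decode kernel.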
+            self.decode_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                input_metadata.req_to_token_pool.req_to_token,
+                input_metadata.req_pool_indices,
+                start_loc,
+                input_metadata.seq_lens,
+                attn_logits,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+        else:
+            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
+            q_label = torch.gather(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                2,
+                input_metadata.sorted_channels[layer.layer_id]
+                .unsqueeze(0)
+                .expand(q.shape[0], -1, -1),
+            )
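+            # q_label mirrors k_label: the same label channels gathered from Q. The sparse
+            # kernel uses these labels to approximate attention scores and attends over at
+            # most heavy_token_num tokens per request.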
+            self.decode_sparse_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q_label,
+                input_metadata.token_to_kv_pool.get_label_buffer(layer.layer_id),
+                ds_req_to_token,
+                input_metadata.seq_lens,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+                input_metadata.heavy_token_num,
+            )
+
+        return o