@@ -484,7 +484,7 @@ def init_memory_pool_and_cache(self):
484
484
self .tree_cache = HiRadixCache (
485
485
req_to_token_pool = self .req_to_token_pool ,
486
486
token_to_kv_pool_allocator = self .token_to_kv_pool_allocator ,
487
- tp_cache_group = self .tp_worker . get_tp_cpu_group () ,
487
+ tp_cache_group = self .tp_cpu_group ,
488
488
page_size = self .page_size ,
489
489
hicache_ratio = server_args .hicache_ratio ,
490
490
)
@@ -553,7 +553,7 @@ def init_disaggregation(self):
553
553
554
554
# The decode requests polling kv cache
555
555
self .disagg_decode_transfer_queue = DecodeTransferQueue (
556
- gloo_group = self .tp_worker . get_attention_tp_cpu_group () ,
556
+ gloo_group = self .attn_tp_cpu_group ,
557
557
req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator ,
558
558
metadata_buffers = metadata_buffers ,
559
559
)
@@ -568,7 +568,7 @@ def init_disaggregation(self):
568
568
scheduler = self ,
569
569
transfer_queue = self .disagg_decode_transfer_queue ,
570
570
tree_cache = self .tree_cache ,
571
- gloo_group = self .tp_worker . get_attention_tp_cpu_group () ,
571
+ gloo_group = self .attn_tp_cpu_group ,
572
572
tp_rank = self .tp_rank ,
573
573
tp_size = self .tp_size ,
574
574
bootstrap_port = self .server_args .disaggregation_bootstrap_port ,
@@ -597,7 +597,7 @@ def init_disaggregation(self):
597
597
tp_rank = self .tp_rank ,
598
598
tp_size = self .tp_size ,
599
599
bootstrap_port = self .server_args .disaggregation_bootstrap_port ,
600
- gloo_group = self .tp_worker . get_attention_tp_cpu_group () ,
600
+ gloo_group = self .attn_tp_cpu_group ,
601
601
transfer_backend = self .transfer_backend ,
602
602
scheduler = self ,
603
603
)
@@ -676,6 +676,11 @@ def event_loop_normal_disagg_prefill(self):
676
676
)
677
677
self .process_prefill_chunk ()
678
678
batch = self .get_new_batch_prefill ()
679
+
680
+ # Handle DP attention
681
+ if self .server_args .enable_dp_attention or self .server_args .enable_sp_layernorm :
682
+ batch , _ = self .prepare_dp_attn_batch (batch )
683
+
679
684
self .cur_batch = batch
680
685
681
686
if batch :
@@ -704,11 +709,18 @@ def event_loop_normal_disagg_decode(self):
704
709
# polling and allocating kv cache
705
710
self .process_decode_queue ()
706
711
batch = self .get_next_disagg_decode_batch_to_run ()
712
+
713
+ # Handle DP attention
714
+ if self .server_args .enable_dp_attention or self .server_args .enable_sp_layernorm :
715
+ batch , do_extend = self .prepare_dp_attn_batch (batch )
716
+ elif batch :
717
+ do_extend = batch .forward_mode .is_extend ()
718
+
707
719
self .cur_batch = batch
708
720
709
721
if batch :
710
722
# Generate fake extend output.
711
- if batch . forward_mode . is_extend () :
723
+ if do_extend :
712
724
# Note: Logprobs should be handled on the prefill engine.
713
725
self .stream_output (
714
726
batch .reqs , [False for _ in range (len (batch .reqs ))]
0 commit comments