@@ -503,7 +503,7 @@ def init_memory_pool_and_cache(self):
503
503
self .tree_cache = HiRadixCache (
504
504
req_to_token_pool = self .req_to_token_pool ,
505
505
token_to_kv_pool_allocator = self .token_to_kv_pool_allocator ,
506
- tp_cache_group = self .tp_worker . get_tp_cpu_group () ,
506
+ tp_cache_group = self .tp_cpu_group ,
507
507
page_size = self .page_size ,
508
508
hicache_ratio = server_args .hicache_ratio ,
509
509
)
@@ -572,7 +572,7 @@ def init_disaggregation(self):
572
572
573
573
# The decode requests polling kv cache
574
574
self .disagg_decode_transfer_queue = DecodeTransferQueue (
575
- gloo_group = self .tp_worker . get_attention_tp_cpu_group () ,
575
+ gloo_group = self .attn_tp_cpu_group ,
576
576
req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator ,
577
577
metadata_buffers = metadata_buffers ,
578
578
)
@@ -587,7 +587,7 @@ def init_disaggregation(self):
587
587
scheduler = self ,
588
588
transfer_queue = self .disagg_decode_transfer_queue ,
589
589
tree_cache = self .tree_cache ,
590
- gloo_group = self .tp_worker . get_attention_tp_cpu_group () ,
590
+ gloo_group = self .attn_tp_cpu_group ,
591
591
tp_rank = self .tp_rank ,
592
592
tp_size = self .tp_size ,
593
593
bootstrap_port = self .server_args .disaggregation_bootstrap_port ,
@@ -616,7 +616,7 @@ def init_disaggregation(self):
616
616
tp_rank = self .tp_rank ,
617
617
tp_size = self .tp_size ,
618
618
bootstrap_port = self .server_args .disaggregation_bootstrap_port ,
619
- gloo_group = self .tp_worker . get_attention_tp_cpu_group () ,
619
+ gloo_group = self .attn_tp_cpu_group ,
620
620
transfer_backend = self .transfer_backend ,
621
621
scheduler = self ,
622
622
)
@@ -683,70 +683,6 @@ def event_loop_overlap(self):
683
683
684
684
self .last_batch = batch
685
685
686
- @torch .no_grad ()
687
- def event_loop_normal_disagg_prefill (self ):
688
- """A normal scheduler loop for prefill worker in disaggregation mode."""
689
-
690
- while True :
691
- recv_reqs = self .recv_requests ()
692
- self .process_input_requests (recv_reqs )
693
- self .waiting_queue .extend (
694
- self .disagg_prefill_pending_queue .pop_bootstrapped ()
695
- )
696
- self .process_prefill_chunk ()
697
- batch = self .get_new_batch_prefill ()
698
- self .cur_batch = batch
699
-
700
- if batch :
701
- result = self .run_batch (batch )
702
- self .process_batch_result_disagg_prefill (batch , result )
703
-
704
- if len (self .disagg_prefill_inflight_queue ) > 0 :
705
- self .process_disagg_prefill_inflight_queue ()
706
-
707
- if batch is None and len (self .disagg_prefill_inflight_queue ) == 0 :
708
- self .check_memory ()
709
- self .new_token_ratio = self .init_new_token_ratio
710
-
711
- self .last_batch = batch
712
- # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
713
- # Otherwise, it hangs under high concurrency
714
- self .running_batch .batch_is_full = False
715
-
716
- @torch .no_grad ()
717
- def event_loop_normal_disagg_decode (self ):
718
- """A normal scheduler loop for decode worker in disaggregation mode."""
719
-
720
- while True :
721
- recv_reqs = self .recv_requests ()
722
- self .process_input_requests (recv_reqs )
723
- # polling and allocating kv cache
724
- self .process_decode_queue ()
725
- batch = self .get_next_disagg_decode_batch_to_run ()
726
- self .cur_batch = batch
727
-
728
- if batch :
729
- # Generate fake extend output.
730
- if batch .forward_mode .is_extend ():
731
- # Note: Logprobs should be handled on the prefill engine.
732
- self .stream_output (
733
- batch .reqs , [False for _ in range (len (batch .reqs ))]
734
- )
735
- else :
736
- result = self .run_batch (batch )
737
- self .process_batch_result (batch , result )
738
-
739
- if batch is None and (
740
- len (self .disagg_decode_transfer_queue .queue )
741
- + len (self .disagg_decode_prealloc_queue .queue )
742
- == 0
743
- ):
744
- # When the server is idle, do self-check and re-init some states
745
- self .check_memory ()
746
- self .new_token_ratio = self .init_new_token_ratio
747
-
748
- self .last_batch = batch
749
-
750
686
def recv_requests (self ) -> List [Req ]:
751
687
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
752
688
if self .attn_tp_rank == 0 :
0 commit comments