@@ -710,24 +710,28 @@ def event_loop_normal_disagg_decode(self):
710
710
self .process_decode_queue ()
711
711
batch = self .get_next_disagg_decode_batch_to_run ()
712
712
713
+ extend_batch = None
714
+ if batch and batch .forward_mode .is_extend ():
715
+ extend_batch = batch
716
+ batch = None
717
+
713
718
# Handle DP attention
714
719
if self .server_args .enable_dp_attention or self .server_args .enable_sp_layernorm :
715
720
batch , _ = self .prepare_dp_attn_batch (batch )
716
-
717
- self .cur_batch = batch
721
+
722
+ self .cur_batch = extend_batch if extend_batch else batch
723
+
724
+ # Generate fake extend output.
725
+ if extend_batch :
726
+ # Note: Logprobs should be handled on the prefill engine.
727
+ # FIXME: stream_output
728
+ self .stream_output (
729
+ extend_batch .reqs , False
730
+ )
718
731
719
732
if batch :
720
- # Generate fake extend output.
721
- if batch .forward_mode .is_extend ():
722
- # Note: Logprobs should be handled on the prefill engine.
723
- self .stream_output (
724
- batch .reqs , [False for _ in range (len (batch .reqs ))]
725
- )
726
- result = self .run_batch (batch )
727
- self .process_batch_result (batch , result )
728
- else :
729
- result = self .run_batch (batch )
730
- self .process_batch_result (batch , result )
733
+ result = self .run_batch (batch )
734
+ self .process_batch_result (batch , result )
731
735
732
736
if batch is None and (
733
737
len (self .disagg_decode_transfer_queue .queue )
@@ -738,7 +742,7 @@ def event_loop_normal_disagg_decode(self):
738
742
self .check_memory ()
739
743
self .new_token_ratio = self .init_new_token_ratio
740
744
741
- self .last_batch = batch
745
+ self .last_batch = extend_batch if extend_batch else batch
742
746
743
747
def recv_requests (self ) -> List [Req ]:
744
748
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
0 commit comments