@@ -470,26 +470,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
+        is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
-            new_token_ids = req_data.new_token_ids[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
-            # Add the sampled token(s) from the previous step (if any).
-            # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                              req_state.num_tokens)
-            if num_new_tokens == 1:
-                # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(new_token_ids[-1])
-            elif num_new_tokens > 0:
-                req_state.output_token_ids.extend(
-                    new_token_ids[-num_new_tokens:])
+
+            if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+
             # Update the block IDs.
             if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
@@ -513,22 +520,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
             self.input_batch.block_table.append_row(new_block_ids, req_index)
-            # Add new_token_ids to token_ids_cpu.
-            start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(new_token_ids)
-            self.input_batch.token_ids_cpu[
-                req_index, start_token_index:end_token_index] = new_token_ids
-            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                start_index = end_token_index
-                end_token_index += len(spec_token_ids)
+
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not is_last_rank:
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
                 self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-            self.input_batch.num_tokens[req_index] = end_token_index
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                # Add spec_token_ids to token_ids_cpu.
+                spec_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ()))
+                if spec_token_ids:
+                    start_index = end_token_index
+                    end_token_index += len(spec_token_ids)
+                    self.input_batch.token_ids_cpu[
+                        req_index,
+                        start_index:end_token_index] = spec_token_ids
+                # NOTE(woosuk): `num_tokens` here may include spec tokens.
+                self.input_batch.num_tokens[req_index] = end_token_index
 
         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
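
The non-last-rank branch above only appends the tail of the echoed `new_token_ids` that the request state does not already hold. Below is a hedged, self-contained sketch of that arithmetic; the function name and the example scenario are illustrative, not part of the diff.

```python
def tail_of_new_tokens(num_computed_tokens: int, new_token_ids: list[int],
                       num_tokens: int) -> list[int]:
    # Only the portion of the echo that extends past the length the worker
    # already tracks (num_tokens) is actually new; the rest was appended on
    # an earlier step and must not be duplicated.
    num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens
    if num_new_tokens <= 0:
        return []
    return new_token_ids[-num_new_tokens:]

# Example: 9 tokens have their KV computed, the scheduler echoes 2 sampled
# ids, and the request already holds 10 tokens, so exactly one id is new.
assert tail_of_new_tokens(9, [101, 102], 10) == [102]
```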
@@ -1509,6 +1524,30 @@ def execute_model(
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
 
+        # Cache the sampled tokens in the model runner, so that the scheduler
+        # doesn't need to send them back.
+        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+        # the sampled tokens back, because there's no direct communication
+        # between the first-stage worker and the last-stage worker.
+        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+            if not sampled_ids:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            end_idx = start_idx + len(sampled_ids)
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+            req_id = self.input_batch.req_ids[req_idx]
+            req_state = self.requests[req_id]
+            req_state.output_token_ids.extend(sampled_ids)
+
         if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
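
The `execute_model` hunk above writes the freshly sampled ids straight into the persistent CPU buffers instead of relying on the scheduler to echo them on the next step. A minimal standalone sketch of that bookkeeping pattern, assuming a toy batch class (`ToyInputBatch` and its methods are illustrative, not vLLM's actual `InputBatch` API); the field names mirror the diff:

```python
import numpy as np

class ToyInputBatch:
    def __init__(self, max_num_reqs: int, max_model_len: int):
        # One preallocated row of token ids per request slot.
        self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len),
                                      dtype=np.int64)
        # Length of the verified prefix (prompt + sampled tokens).
        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int64)
        # Length including unverified speculative tokens staged after it.
        self.num_tokens = np.zeros(max_num_reqs, dtype=np.int64)

    def append_sampled(self, req_index: int, token_ids: list[int]) -> None:
        # Verified tokens resume at the verified prefix, overwriting any
        # stale draft tokens left over from the previous step.
        start = self.num_tokens_no_spec[req_index]
        end = start + len(token_ids)
        self.token_ids_cpu[req_index, start:end] = token_ids
        self.num_tokens_no_spec[req_index] = end
        self.num_tokens[req_index] = end

    def append_spec(self, req_index: int, token_ids: list[int]) -> None:
        # Draft tokens sit after the verified prefix but do not advance it.
        start = self.num_tokens_no_spec[req_index]
        end = start + len(token_ids)
        self.token_ids_cpu[req_index, start:end] = token_ids
        self.num_tokens[req_index] = end

batch = ToyInputBatch(max_num_reqs=1, max_model_len=16)
batch.append_sampled(0, [11, 12, 13])   # prompt / verified tokens
batch.append_spec(0, [99, 98])          # unverified draft tokens
batch.append_sampled(0, [14])           # next step's verified sample
assert batch.token_ids_cpu[0, :batch.num_tokens[0]].tolist() == [11, 12, 13, 14]
```

The final `append_sampled` starts at `num_tokens_no_spec`, so stale draft tokens from the previous step are overwritten once verified tokens arrive, which appears to be why the diff keeps `num_tokens` and `num_tokens_no_spec` as separate counters.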
@@ -1730,17 +1769,14 @@ def propose_ngram_draft_token_ids(
                 draft_token_ids.append([])
                 continue
 
-            # Add sampled_token_ids to token_ids_cpu.
-            start_idx = self.input_batch.num_tokens_no_spec[i]
-            end_idx = start_idx + num_sampled_ids
-            if end_idx >= self.max_model_len:
+            num_tokens = self.input_batch.num_tokens_no_spec[i]
+            if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.
                 draft_token_ids.append([])
                 continue
 
-            self.input_batch.token_ids_cpu[i, :end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :end_idx])
+                self.input_batch.token_ids_cpu[i, :num_tokens])
             if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
             else:
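
With the sampled ids already cached in `token_ids_cpu`, the last hunk can hand the drafter the full verified prefix directly (`token_ids_cpu[i, :num_tokens]`) instead of patching the buffer first. As a rough illustration of what an n-gram drafter does with that prefix, here is a toy prompt-lookup proposer; it is a sketch under that assumption, not vLLM's `NgramProposer`:

```python
import numpy as np

def propose_by_ngram_lookup(token_ids: np.ndarray, n: int = 2,
                            num_draft: int = 3) -> list[int]:
    tokens = token_ids.tolist()
    if len(tokens) <= n:
        return []
    suffix = tokens[-n:]
    # Scan earlier positions only (the range excludes the suffix's own
    # position), right to left so the most recent match wins.
    for start in range(len(tokens) - n - 1, -1, -1):
        if tokens[start:start + n] == suffix:
            # Propose the tokens that followed the earlier occurrence.
            return tokens[start + n:start + n + num_draft]
    return []

# The context "1 2 3 4 1 2" ends in "1 2", which earlier was followed by
# "3 4", so those become the draft tokens.
context = np.array([1, 2, 3, 4, 1, 2], dtype=np.int64)
assert propose_by_ngram_lookup(context, n=2, num_draft=2) == [3, 4]
```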