
Commit 7f280d6

[Optimization] Cache sampled token ids in model runner (vllm-project#20291)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 02cabff commit 7f280d6

File tree

5 files changed: +91 additions, -45 deletions

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 6 additions & 6 deletions
@@ -172,7 +172,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
             req_state.block_ids[0]).all()
 
 
-def test_update_states_new_request(model_runner):
+def test_update_states_new_request(model_runner, dist_init):
     req_id = "req_0"
 
     # new req
@@ -186,7 +186,7 @@ def test_update_states_new_request(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)
 
 
-def test_update_states_request_finished(model_runner):
+def test_update_states_request_finished(model_runner, dist_init):
     req_id = "req_0"
 
     # new req
@@ -218,7 +218,7 @@ def test_update_states_request_finished(model_runner):
     assert not _is_req_scheduled(model_runner, req_id)
 
 
-def test_update_states_request_resumed(model_runner):
+def test_update_states_request_resumed(model_runner, dist_init):
     req_id = "req_0"
 
     # new req
@@ -278,7 +278,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)
 
 
-def test_get_nans_in_logits(model_runner):
+def test_get_nans_in_logits(model_runner, dist_init):
     req_ids = ("req_0", "req_1")
 
     scheduler_output = _schedule_new_request(*req_ids)
@@ -326,7 +326,7 @@ def test_get_nans_in_logits(model_runner):
     assert result == {'req_0': 2, 'req_1': 0}
 
 
-def test_update_states_no_changes(model_runner):
+def test_update_states_no_changes(model_runner, dist_init):
     req_id = "req_0"
 
     # new req
@@ -359,7 +359,7 @@ def test_update_states_no_changes(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)
 
 
-def test_update_states_request_unscheduled(model_runner):
+def test_update_states_request_unscheduled(model_runner, dist_init):
     req_ids = ("req_0", "req_1")
 
     # new reqs
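
The new dist_init fixture parameter is needed because _update_states now calls get_pp_group(), which requires an initialized distributed environment even in single-process tests. As a hedged illustration only (the actual fixture in the vLLM test suite may be defined elsewhere and differently), a minimal single-rank fixture could look roughly like this, assuming init_distributed_environment and initialize_model_parallel from vllm.distributed:

# Hypothetical sketch of a single-process dist_init fixture; the real fixture
# used by these tests may be defined differently.
import pytest

from vllm.distributed import (init_distributed_environment,
                              initialize_model_parallel)
from vllm.distributed.parallel_state import (
    destroy_distributed_environment, destroy_model_parallel)


@pytest.fixture
def dist_init():
    # Single-rank "world" so that get_pp_group() works inside _update_states.
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method="tcp://127.0.0.1:29500",
        local_rank=0,
        backend="gloo",
    )
    initialize_model_parallel(tensor_model_parallel_size=1,
                              pipeline_model_parallel_size=1)
    yield
    destroy_model_parallel()
    destroy_distributed_environment()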

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 2 additions & 2 deletions
@@ -307,7 +307,7 @@ def build_connector_meta(
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
             num_computed_tokens = cached_reqs.num_computed_tokens[i]
-            new_token_ids = cached_reqs.new_token_ids[i]
+            num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
             new_block_ids = cached_reqs.new_block_ids[i]
             resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
 
@@ -320,7 +320,7 @@ def build_connector_meta(
                 # list of token ids (only new tokens). So we look it
                 # up in the actual request object.
                 request = self._requests_need_load[req_id]
-                total_tokens = (len(new_token_ids) + num_computed_tokens)
+                total_tokens = num_computed_tokens + num_new_tokens
                 token_ids = request.all_token_ids[:total_tokens]
 
                 # NOTE(rob): For resumed req, new_block_ids is all
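
Because new_token_ids can now be empty whenever pipeline parallelism is off, the connector can no longer derive the per-step token count from its length, so it reads scheduler_output.num_scheduled_tokens[req_id] instead. A minimal sketch of the arithmetic, with made-up numbers purely for illustration:

# Made-up example values; in the connector these come from scheduler_output
# and the cached request state.
num_computed_tokens = 100   # cached_reqs.num_computed_tokens[i]
num_new_tokens = 1          # scheduler_output.num_scheduled_tokens[req_id]

# Old form relied on the echoed token ids, which may now be empty:
#   total_tokens = len(new_token_ids) + num_computed_tokens  -> 100 (wrong)
# New form counts the scheduled tokens directly:
total_tokens = num_computed_tokens + num_new_tokens          # 101

all_token_ids = list(range(200))          # stand-in for request.all_token_ids
token_ids = all_token_ids[:total_tokens]  # tokens the connector needs to load
assert len(token_ids) == 101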

vllm/v1/core/sched/output.py

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,8 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: list[bool]
+    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
+    # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
     new_block_ids: list[tuple[list[int], ...]]
     num_computed_tokens: list[int]
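
The new comment pins down the contract: without PP the scheduler leaves new_token_ids empty, and each worker relies on the sampled tokens it cached itself. A simplified stand-in for the dataclass, mirroring only the fields visible in this hunk (the real CachedRequestData has more), shows what a non-PP payload looks like:

# Simplified illustration only; this is not the real CachedRequestData, just
# the fields shown in the diff above.
from dataclasses import dataclass, field


@dataclass
class CachedRequestDataSketch:
    resumed_from_preemption: list[bool] = field(default_factory=list)
    # Only populated when pipeline parallelism is enabled.
    new_token_ids: list[list[int]] = field(default_factory=list)
    new_block_ids: list[tuple[list[int], ...]] = field(default_factory=list)
    num_computed_tokens: list[int] = field(default_factory=list)


# Non-PP case: the sampled-token lists stay empty, shrinking the payload the
# scheduler ships to the workers every step.
data = CachedRequestDataSketch(
    resumed_from_preemption=[False],
    new_token_ids=[],                 # empty when PP is not used
    new_block_ids=[([3, 4],)],
    num_computed_tokens=[100],
)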

vllm/v1/core/sched/scheduler.py

Lines changed: 13 additions & 5 deletions
@@ -55,6 +55,7 @@ def __init__(
         self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
         self.kv_events_config = vllm_config.kv_events_config
+        self.parallel_config = vllm_config.parallel_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
 
@@ -87,7 +88,7 @@ def __init__(
 
         self.kv_event_publisher = EventPublisherFactory.create(
             self.kv_events_config,
-            vllm_config.parallel_config.data_parallel_rank,
+            self.parallel_config.data_parallel_rank,
         )
 
         num_gpu_blocks = self.cache_config.num_gpu_blocks
@@ -159,6 +160,7 @@ def __init__(
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
         )
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -214,7 +216,7 @@ def schedule(self) -> SchedulerOutput:
             # This is necessary when using spec decoding.
             num_new_tokens = min(
                 num_new_tokens,
-                self.max_model_len - request.num_computed_tokens)
+                self.max_model_len - 1 - request.num_computed_tokens)
 
             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
@@ -624,9 +626,15 @@ def _make_cached_request_data(
             req_ids.append(req_id)
             num_tokens = (num_scheduled_tokens[req_id] -
                           len(spec_decode_tokens.get(req_id, ())))
-            token_ids = req.all_token_ids[req.num_computed_tokens:req.
-                                          num_computed_tokens + num_tokens]
-            new_token_ids.append(token_ids)
+            if self.use_pp:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker. Otherwise, we don't
+                # need to send the sampled tokens back because the model runner
+                # will cache them.
+                token_ids = req.all_token_ids[req.num_computed_tokens:req.
+                                              num_computed_tokens + num_tokens]
+                new_token_ids.append(token_ids)
             new_block_ids.append(req_to_new_block_ids[req_id])
             num_computed_tokens.append(req.num_computed_tokens)
         # Because resumed_reqs is usually empty, it is more efficient to do
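
The gating added to _make_cached_request_data is the scheduler-side half of the optimization: token slices are materialized and shipped only when use_pp is set. A stripped-down, hypothetical version of that decision (simplified names, not the real method) is sketched below:

# Hypothetical, stripped-down version of the PP gating; names are simplified
# and this is not the actual scheduler method.
def collect_new_token_ids(use_pp: bool, all_token_ids: list[int],
                          num_computed_tokens: int,
                          num_tokens: int) -> list[list[int]]:
    new_token_ids: list[list[int]] = []
    if use_pp:
        # With PP, the first-stage worker never sees the sampled tokens, so
        # the scheduler must echo them back in CachedRequestData.
        token_ids = all_token_ids[num_computed_tokens:num_computed_tokens +
                                  num_tokens]
        new_token_ids.append(token_ids)
    # Without PP, the model runner already cached the sampled tokens, so
    # nothing is sent and new_token_ids stays empty.
    return new_token_ids


assert collect_new_token_ids(False, list(range(105)), 100, 1) == []
assert collect_new_token_ids(True, list(range(105)), 100, 1) == [[100]]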

vllm/v1/worker/gpu_model_runner.py

Lines changed: 68 additions & 32 deletions
@@ -470,26 +470,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
+        is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
-            new_token_ids = req_data.new_token_ids[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
-            # Add the sampled token(s) from the previous step (if any).
-            # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                              req_state.num_tokens)
-            if num_new_tokens == 1:
-                # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(new_token_ids[-1])
-            elif num_new_tokens > 0:
-                req_state.output_token_ids.extend(
-                    new_token_ids[-num_new_tokens:])
+
+            if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+
             # Update the block IDs.
             if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
@@ -513,22 +520,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
             self.input_batch.block_table.append_row(new_block_ids, req_index)
-            # Add new_token_ids to token_ids_cpu.
-            start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(new_token_ids)
-            self.input_batch.token_ids_cpu[
-                req_index, start_token_index:end_token_index] = new_token_ids
-            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                start_index = end_token_index
-                end_token_index += len(spec_token_ids)
+
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not is_last_rank:
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
                 self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-            self.input_batch.num_tokens[req_index] = end_token_index
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                # Add spec_token_ids to token_ids_cpu.
+                spec_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ()))
+                if spec_token_ids:
+                    start_index = end_token_index
+                    end_token_index += len(spec_token_ids)
+                    self.input_batch.token_ids_cpu[
+                        req_index,
+                        start_index:end_token_index] = spec_token_ids
+                # NOTE(woosuk): `num_tokens` here may include spec tokens.
+                self.input_batch.num_tokens[req_index] = end_token_index
 
         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
@@ -1509,6 +1524,30 @@ def execute_model(
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
 
+        # Cache the sampled tokens in the model runner, so that the scheduler
+        # doesn't need to send them back.
+        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+        # the sampled tokens back, because there's no direct communication
+        # between the first-stage worker and the last-stage worker.
+        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+            if not sampled_ids:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            end_idx = start_idx + len(sampled_ids)
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+            req_id = self.input_batch.req_ids[req_idx]
+            req_state = self.requests[req_id]
+            req_state.output_token_ids.extend(sampled_ids)
+
         if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
@@ -1730,17 +1769,14 @@ def propose_ngram_draft_token_ids(
                 draft_token_ids.append([])
                 continue
 
-            # Add sampled_token_ids to token_ids_cpu.
-            start_idx = self.input_batch.num_tokens_no_spec[i]
-            end_idx = start_idx + num_sampled_ids
-            if end_idx >= self.max_model_len:
+            num_tokens = self.input_batch.num_tokens_no_spec[i]
+            if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.
                 draft_token_ids.append([])
                 continue
 
-            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
            drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :end_idx])
+                self.input_batch.token_ids_cpu[i, :num_tokens])
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
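
The new loop in execute_model is the worker-side half of the optimization: the last-rank worker writes its own sampled tokens into token_ids_cpu and the per-request state, so the next step's _update_states no longer needs them from the scheduler. A self-contained sketch of that bookkeeping using plain numpy arrays (made-up shapes and values, not vLLM's InputBatch):

# Standalone illustration of the caching step, using plain numpy arrays in
# place of vLLM's InputBatch; shapes and values are invented for the example.
import numpy as np

max_model_len = 16
token_ids_cpu = np.zeros((2, max_model_len), dtype=np.int64)
num_tokens_no_spec = np.array([5, 7])   # tokens known per request so far
num_tokens = num_tokens_no_spec.copy()
output_token_ids: list[list[int]] = [[], []]

# Request 0 sampled one token; request 1 sampled two (e.g. a spec token was
# accepted along with the bonus token).
valid_sampled_token_ids = [[11], [21, 22]]

for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
    if not sampled_ids:
        continue
    start_idx = num_tokens_no_spec[req_idx]
    end_idx = start_idx + len(sampled_ids)
    assert end_idx <= max_model_len
    # Append the sampled tokens locally instead of waiting for the scheduler
    # to echo them back on the next step.
    token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
    num_tokens_no_spec[req_idx] = end_idx
    num_tokens[req_idx] = end_idx
    output_token_ids[req_idx].extend(sampled_ids)

assert token_ids_cpu[0, 5] == 11
assert list(token_ids_cpu[1, 7:9]) == [21, 22]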
