[V1][Spec Decode] Handle draft tokens beyond max_model_len #16087
vllm/v1/spec_decode/eagle.py

@@ -9,6 +9,8 @@
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
+PADDING_SLOT_ID = -1
+
 
 class EagleProposer:
 
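The new PADDING_SLOT_ID constant is a sentinel slot index for tokens whose keys/values must not land in the paged KV cache. As a hedged sketch of that convention only (the actual skipping happens inside the attention backend's cache-write kernel; write_kv and its shapes are illustrative, not vLLM's API):

```python
import torch

PADDING_SLOT_ID = -1

def write_kv(kv_cache: torch.Tensor, new_kv: torch.Tensor,
             slot_mapping: torch.Tensor) -> None:
    """kv_cache: [num_slots, hidden]; new_kv: [num_tokens, hidden]."""
    valid = slot_mapping != PADDING_SLOT_ID        # sentinel entries are skipped
    kv_cache[slot_mapping[valid]] = new_kv[valid]  # only real slots get updated
```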
@@ -20,7 +22,10 @@ def __init__(
         self.vllm_config = vllm_config
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
+        self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
+
+        # Optimization: pre-compute and cache the arange tensor.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
                                    device=device)
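The cached self.arange avoids allocating a fresh index tensor on every proposal step. A minimal sketch of the same pattern, assuming the illustrative names ArangeCache and for_batch (not part of vLLM):

```python
import torch

class ArangeCache:
    def __init__(self, max_num_seqs: int, device: torch.device):
        # One allocation up front, sized for the largest possible batch.
        self.arange = torch.arange(max_num_seqs, device=device)

    def for_batch(self, num_seqs: int) -> torch.Tensor:
        # Slicing returns a view; no new allocation per step.
        return self.arange[:num_seqs]
```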
@@ -103,22 +108,52 @@ def propose(
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
             positions += 1
 
+            # NOTE(woosuk): We should handle the case where the draft model
+            # generates tokens beyond the max model length. Since it is complex
+            # to remove such requests from the batch, we keep them in the batch
+            # but adjust the position ids and slot mappings to avoid the
+            # out-of-range access during the model execution. The draft tokens
+            # generated with this adjustment should be ignored.
+            exceeds_max_model_len = positions >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)
+
             # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overheads in attention.
+            attn_metadata.seq_lens = torch.where(exceeds_max_model_len, 1,
+                                                 attn_metadata.seq_lens)
 
             # Compute the slot mapping.
-            block_numbers = positions // self.block_size
+            block_numbers = clamped_positions // self.block_size
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          positions % self.block_size)
+                                          clamped_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            attn_metadata.slot_mapping = torch.where(
+                exceeds_max_model_len,
+                PADDING_SLOT_ID,
+                attn_metadata.slot_mapping,
+            )
 
             # Run the model.
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
                     hidden_states=hidden_states,
-                    positions=positions,
+                    positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
             draft_token_ids, probs = compute_probs_and_sample_next_token(

Inline review thread on the slot-mapping masking (both comments are truncated in the capture):

Reviewer: My understanding is that …

Author: @ekagra-ranjan Thanks for the suggestion. I changed it to …
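To see the adjustment in isolation, here is a minimal self-contained sketch that mirrors the position clamping and slot-mapping masking above; the function name clamp_and_mask, the tensor shapes, and the example batch are illustrative and not part of the PR:

```python
import torch

PADDING_SLOT_ID = -1

def clamp_and_mask(positions: torch.Tensor, block_table: torch.Tensor,
                   block_size: int, max_model_len: int):
    """positions: [num_reqs]; block_table: [num_reqs, max_blocks_per_req]."""
    exceeds = positions >= max_model_len
    # Clamp positions of out-of-range requests to 0 so RoPE never indexes
    # past its table of rotary embeddings.
    clamped = torch.where(exceeds, 0, positions)
    # Translate (logical block, offset) into a flat slot index of the paged cache.
    block_numbers = clamped // block_size
    block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
    slot_mapping = block_ids * block_size + clamped % block_size
    # Redirect out-of-range requests to the padding slot so their (to-be-ignored)
    # KV is never written into a real cache slot.
    slot_mapping = torch.where(exceeds, PADDING_SLOT_ID, slot_mapping)
    return clamped, slot_mapping

# Two requests with max_model_len = 8; the second one is past the limit.
positions = torch.tensor([5, 8])
block_table = torch.tensor([[3, 7], [2, 9]])
clamped, slots = clamp_and_mask(positions, block_table, block_size=4, max_model_len=8)
print(clamped.tolist(), slots.tolist())  # [5, 0] [29, -1]
```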
Reviewer: Can you please point me to the logic of ignoring such draft tokens in this PR?

Author: Good question. The scheduler handles it:
vllm/v1/core/sched/scheduler.py, lines 188 to 193 (at commit af7462b)
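The referenced scheduler lines are not shown in this capture. As a hedged illustration only, not the actual code at vllm/v1/core/sched/scheduler.py, the behavior the reply describes amounts to capping the number of tokens scheduled per request at max_model_len, so draft tokens proposed beyond the limit are never scheduled and are therefore ignored:

```python
def cap_new_tokens(num_computed_tokens: int, num_new_tokens: int,
                   max_model_len: int) -> int:
    """Hypothetical helper (not vLLM's API): clip a request's scheduled tokens
    so the request never runs past max_model_len."""
    return min(num_new_tokens, max_model_len - num_computed_tokens)

# A request at 4095/4096 tokens proposing 3 draft tokens (plus the bonus token)
# only gets 1 token scheduled; the out-of-range drafts are dropped.
assert cap_new_tokens(4095, 1 + 3, 4096) == 1
```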