
[V1][Spec Decode] Handle draft tokens beyond max_model_len #16087


Merged
12 commits merged on Apr 21, 2025
41 changes: 38 additions & 3 deletions vllm/v1/spec_decode/eagle.py
@@ -9,6 +9,8 @@
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata

PADDING_SLOT_ID = -1


class EagleProposer:

@@ -20,7 +22,10 @@ def __init__(
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size

# Optimization: pre-compute and cache the arange tensor.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)

@@ -103,22 +108,52 @@ def propose(
# Update the inputs.
input_ids = draft_token_ids_list[-1]
positions += 1

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
Comment on lines +123 to +124

ekagra-ranjan (Contributor) commented on Apr 7, 2025:
Can you please point me to the logic of ignoring such draft tokens in this PR?

Collaborator (Author) replied:

Good question. The scheduler handles it:

# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
assert num_new_tokens > 0
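
For intuition, a hedged numeric walk-through of that clamp (the values below are made up, not taken from the PR):

    max_model_len = 2048
    num_computed_tokens = 2045          # tokens already in the KV cache
    num_new_tokens = 8                  # e.g. 1 sampled token + 7 draft tokens

    # Scheduler-side clamp: never schedule past the max model length.
    num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)  # -> 3
    assert num_new_tokens > 0           # the extra draft tokens are simply dropped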

exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)

# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens = torch.where(exceeds_max_model_len, 1,
attn_metadata.seq_lens)

# Compute the slot mapping.
block_numbers = positions // self.block_size
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
positions % self.block_size)
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping = torch.where(
exceeds_max_model_len,
PADDING_SLOT_ID,
attn_metadata.slot_mapping,
)
ekagra-ranjan (Contributor) commented on Apr 7, 2025:

My understanding is that torch.where will allocate an intermediate tensor and then assign it.
Is it possible to use attn_metadata.slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID so that it's an in-place operation?

Collaborator (Author) replied:

@ekagra-ranjan Thanks for the suggestion. I changed it to masked_fill_, which is also an in-place operation.
Overall, I think the performance impact will be small since the tensors here are small (shape of [batch_size]).
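
For reference, a minimal sketch of the two variants discussed in this thread (plain PyTorch; the tensors are toy stand-ins for attn_metadata.slot_mapping and exceeds_max_model_len):

    import torch

    PADDING_SLOT_ID = -1
    slot_mapping = torch.tensor([100, 101, 102, 103])
    exceeds_max_model_len = torch.tensor([False, False, True, True])

    # Out-of-place: torch.where allocates a new tensor that is then rebound.
    masked = torch.where(exceeds_max_model_len, PADDING_SLOT_ID, slot_mapping)

    # In-place alternative adopted in the PR: masked_fill_ mutates the tensor directly.
    slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

Either way the result is [100, 101, -1, -1]; as noted above, the difference is negligible for tensors of shape [batch_size].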


# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
positions=positions,
positions=clamped_positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(
7 changes: 6 additions & 1 deletion vllm/v1/spec_decode/ngram_proposer.py
@@ -18,6 +18,9 @@ def __init__(self, vllm_config: VllmConfig):
# tokens follow the match, we will return the maximum amount of
# tokens until the end.
self.k = vllm_config.speculative_config.num_speculative_tokens
# Maximum length of the model.
self.max_model_len = vllm_config.model_config.max_model_len

# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32))
@@ -50,9 +53,11 @@ def propose(
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# Do not generate draft tokens beyond the max model length.
k = min(self.k, self.max_model_len - context_token_ids.shape[0])
# TODO(woosuk): Optimize this.
for n in range(self.max_n, self.min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, self.k)
result = _find_subarray_kmp(context_token_ids, n, k)
if result is not None:
return result
return None
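
For intuition, a small hedged example of the new clamp in the n-gram proposer (numbers are hypothetical):

    k = 5                              # num_speculative_tokens
    max_model_len = 2048
    context_len = 2046                 # context_token_ids.shape[0]

    k = min(k, max_model_len - context_len)   # -> 2: at most 2 draft tokens still fit

Requests whose context has already reached max_model_len never get here; the model runner change below skips them before calling propose().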
8 changes: 7 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -1235,7 +1235,8 @@ def generate_draft_token_ids(
draft_token_ids.append([])
continue

# Skip requests that require top-p, top-k, etc.
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
@@ -1244,6 +1245,11 @@
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
if end_idx >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue

self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
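
To make the new early exit concrete, a hedged example (values are illustrative, not from the PR):

    max_model_len = 4096
    start_idx = 4094                   # tokens already stored for this request
    num_sampled_ids = 2
    end_idx = start_idx + num_sampled_ids   # 4096 >= max_model_len

    # end_idx >= max_model_len, so no draft tokens are proposed for this request:
    # an empty list is appended and the loop moves on to the next request.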