
[V1][Spec Decode] Handle draft tokens beyond max_model_len #16087


Merged · 12 commits · Apr 21, 2025
41 changes: 38 additions & 3 deletions vllm/v1/spec_decode/eagle.py
@@ -9,6 +9,8 @@
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata

PADDING_SLOT_ID = -1


class EagleProposer:

@@ -20,7 +22,10 @@ def __init__(
self.vllm_config = vllm_config
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size

# Optimization: pre-compute and cache the arange tensor.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)

@@ -103,22 +108,52 @@ def propose(
# Update the inputs.
input_ids = draft_token_ids_list[-1]
positions += 1

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove the request from the batch, we keep the request in the
# batch but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated through this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)

# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize the overheads in attention.
attn_metadata.seq_lens = torch.where(exceeds_max_model_len, 1,
attn_metadata.seq_lens)

# Compute the slot mapping.
- block_numbers = positions // self.block_size
+ block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
- positions % self.block_size)
+ clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping = torch.where(
exceeds_max_model_len,
PADDING_SLOT_ID,
attn_metadata.slot_mapping,
)
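For readers following the diff, here is a minimal standalone sketch of how the clamping and slot-mapping masking above behaves. The tensor values, block table, and sizes below are made up for illustration; they are not the engine's real state.

```python
import torch

PADDING_SLOT_ID = -1
max_model_len = 8
block_size = 4

# Positions after the "+1" increment; the third request has run past the limit.
positions = torch.tensor([5, 7, 8])
# Toy block table: two physical blocks allocated per request.
block_table = torch.tensor([[0, 1],
                            [2, 3],
                            [4, 5]])

exceeds_max_model_len = positions >= max_model_len        # [False, False, True]
# Clamp to 0 so RoPE never sees an out-of-range position.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)  # [5, 7, 0]

block_numbers = clamped_positions // block_size           # [1, 1, 0]
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Redirect the out-of-range request to the padding slot so its KV write
# does not land in a real cache slot.
slot_mapping = torch.where(exceeds_max_model_len, PADDING_SLOT_ID, slot_mapping)
print(slot_mapping)  # tensor([ 5, 15, -1])
```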
@ekagra-ranjan (Contributor) commented on Apr 7, 2025:
My understanding is that torch.where will allocate an intermediate tensor and then assign it.
Is it possible to use attn_metadata.slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID so that it's an in-place operation?

The PR author (Collaborator) replied:
@ekagra-ranjan Thanks for the suggestion. I changed it to masked_fill_, which is also an in-place operation.
Overall, I think the performance impact will be small since the tensors here are small (shape of [batch_size]).
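For reference, a toy comparison of the three variants discussed in this thread; the tensors are illustrative only. All three leave the same values; the only difference is whether an intermediate tensor is allocated.

```python
import torch

PADDING_SLOT_ID = -1
exceeds_max_model_len = torch.tensor([False, False, True])
slot_mapping = torch.tensor([5, 15, 16])

# 1) Out-of-place: torch.where builds a new tensor and rebinds the name.
slot_mapping = torch.where(exceeds_max_model_len, PADDING_SLOT_ID, slot_mapping)

# 2) In-place boolean-index assignment, as suggested above.
slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID

# 3) In-place masked_fill_, which the PR switched to.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

print(slot_mapping)  # tensor([ 5, 15, -1])
```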


# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
- positions=positions,
+ positions=clamped_positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(