[V1][Spec Decode] Handle draft tokens beyond max_model_len #16087

Merged: 12 commits into main on Apr 21, 2025

Conversation

@WoosukKwon (Collaborator) commented Apr 5, 2025

Implements item 4 of #15901: handle edge cases where the draft model generates tokens beyond max_pos_embeddings.
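
For context, a minimal sketch of the approach discussed in this thread (illustrative only, not the exact vLLM code; the helper name mask_out_of_range and the PADDING_SLOT_ID value of -1 are assumptions): draft positions that would exceed max_model_len are clamped to a valid index and their KV-cache writes are redirected to a dummy padding slot, so the model never indexes out of range, and the corresponding draft tokens are discarded later.

import torch

PADDING_SLOT_ID = -1  # illustrative placeholder value

def mask_out_of_range(positions: torch.Tensor,
                      slot_mapping: torch.Tensor,
                      max_model_len: int):
    # Mark draft positions that would fall outside the supported context.
    exceeds_max_model_len = positions >= max_model_len
    # Clamp offending positions to a valid index (0 here) so position
    # embeddings are never indexed out of range.
    clamped_positions = torch.where(exceeds_max_model_len,
                                    torch.zeros_like(positions),
                                    positions)
    # Redirect the KV-cache writes of those tokens to a dummy padding slot.
    slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
    return clamped_positions, slot_mapping

# Example: with max_model_len=2048, positions [2046, 2047, 2048, 2049]
# become [2046, 2047, 0, 0] and the last two slots become PADDING_SLOT_ID.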

github-actions bot commented Apr 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@WoosukKwon changed the title from "[Spec Decode] Do not generate draft tokens beyond max_model_len" to "[V1][Spec Decode] Do not generate draft tokens beyond max_model_len" Apr 5, 2025
@mergify bot added the v1 label Apr 5, 2025
@comaniac (Collaborator) left a comment

Overall LGTM. Approving first to unblock the PR. Meanwhile, it would be good to have a unit test for this. Also, do we know the overhead of the introduced ops (e.g., torch.where)?

@comaniac added the needs-tests label Apr 5, 2025
Signed-off-by: Woosuk Kwon <[email protected]>
Comment on lines 145 to 149
attn_metadata.slot_mapping = torch.where(
exceeds_max_model_len,
PADDING_SLOT_ID,
attn_metadata.slot_mapping,
)
@ekagra-ranjan (Contributor) commented Apr 7, 2025

My understanding is that torch.where will allocate an intermediate tensor and then assign it.
Is it possible to use attn_metadata.slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID so that it's an in-place operation?

@WoosukKwon (Collaborator, Author) replied:

@ekagra-ranjan Thanks for the suggestion. I changed it to masked_fill_, which is also an in-place operation.
Overall, I think the performance impact will be small since the tensors here are small (shape of [batch_size]).
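
For illustration, a small self-contained comparison of the two variants discussed above (the values and the PADDING_SLOT_ID of -1 are made up for the example): torch.where allocates a new tensor that is then assigned back, whereas masked_fill_ overwrites the masked entries of the existing tensor in place.

import torch

PADDING_SLOT_ID = -1  # assumed placeholder value for the example
slot_mapping = torch.tensor([3, 7, 11, 15])
exceeds_max_model_len = torch.tensor([False, False, True, True])

# Out-of-place: builds an intermediate tensor, then rebinds the name.
new_slot_mapping = torch.where(
    exceeds_max_model_len,
    torch.full_like(slot_mapping, PADDING_SLOT_ID),
    slot_mapping,
)

# In-place: overwrites the masked entries directly, with no extra allocation.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

# Both produce tensor([ 3,  7, -1, -1]).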

Comment on lines +116 to +117
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
@ekagra-ranjan (Contributor) commented Apr 7, 2025

Can you please point me to the logic of ignoring such draft tokens in this PR?

@WoosukKwon (Collaborator, Author) replied:

Good question. The scheduler handles it:

# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
assert num_new_tokens > 0
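
A small numeric illustration of this clamp (the numbers are made up): with max_model_len = 2048, a request that has already computed 2045 tokens, and 6 tokens proposed for the next step, the scheduler reduces the count to 3, so any draft tokens that would land beyond the limit are simply never scheduled.

max_model_len = 2048
num_computed_tokens = 2045
num_new_tokens = 6  # e.g. 1 verified token plus 5 speculative draft tokens
num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)
assert num_new_tokens == 3  # drafts past the 2048-token limit are dropped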

mergify bot commented Apr 10, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot removed the needs-rebase label Apr 21, 2025
Signed-off-by: Woosuk Kwon <[email protected]>
@WoosukKwon added the ready label and removed the needs-tests label Apr 21, 2025
@WoosukKwon (Collaborator, Author) commented:

@comaniac Good point. Added a test. Also, I replaced torch.where with masked_fill_ wherever possible, for better performance. Overall, I think the overhead will be very small because the masked tensors are all small (i.e., shape of [batch_size]).

Signed-off-by: Woosuk Kwon <[email protected]>
@WoosukKwon merged commit 3a0fba5 into main Apr 21, 2025
42 of 44 checks passed
@WoosukKwon deleted the eagle-max-len branch April 21, 2025 19:38
@WoosukKwon changed the title from "[V1][Spec Decode] Do not generate draft tokens beyond max_model_len" to "[V1][Spec Decode] Handle draft tokens beyond max_model_len" Apr 21, 2025
frieda-huang pushed a commit to frieda-huang/vllm that referenced this pull request Apr 23, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Labels: ready, speculative-decoding, v1
4 participants