Skip to content

[TPU][V1] Fix padding recompilation when max-num-batched-tokens is not even #16726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

NickLucche
Copy link
Contributor

@NickLucche NickLucche commented Apr 16, 2025

max-num-batched-tokens values that are not power of 2s (or simply not even when using bucketing) can silently cause recompiliations.
This is due to the fact in such cases both bucketed and exponential padding will return a maximal padding value that will be > max-num-batched-tokens.
Assume:

# 317
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.num_tokens_paddings = _get_token_paddings(..)
print(self.num_tokens_paddings[-1]) # 512 for exponential, 320 for bucketed

In turn, auxiliary data structures will be instantiated with the original uneven value (say 317)

self.input_ids_cpu = torch.zeros(self.max_num_tokens, # 317
                                 dtype=torch.int32,
                                 device="cpu")

Causing the following

# Padded value=512, which is > 317
padded_total_num_scheduled_tokens = _get_padded_token_len(
    self.num_tokens_paddings, total_num_scheduled_tokens)
...
# Select 512 positions in 317 position array will yield the full 317 positions
self.input_ids = self.input_ids_cpu[:
                                    padded_total_num_scheduled_tokens].to(
                                        self.device)
# Causing `input_ids` to run with a potentially uncompiled size of 317

To verify:

VLLM_XLA_CACHE_PATH= VLLM_XLA_CHECK_RECOMPILATION=1 VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct \
 --port 8004 \
 --gpu-memory-utilization 0.95 \
 --max-num-seqs 8 \
 --max-num-batched-tokens 93 \
 --tensor-parallel-size 1 \
 --max-model-len 256
 
 
 # Some long request
 curl http://localhost:8004/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen2.5-1.5B-Instruct",
    "messages": [
      {
        "role": "system",
        "content": "You are a helpful assistant."
      },
      {
        "role": "user",
        "content": "Imagine a lone samurai wandering across feudal Japan, passing through misty mountains, quiet villages, and fields touched by the morning dew. He reflects on the impermanence of life, the fleeting nature of honor, and the silent strength required to walk a solitary path. His sword is sheathed not just in steel, but in wisdom. His heart carries both the burden of war and the beauty of cherry blossoms falling in the spring breeze. He meets no one, yet learns from every shadow, every whisper of wind, and every broken piece of pottery he finds along the way. Please write a short poem that captures the spirit of this samurai journey and his internal reflections. The tone should be serene, slightly melancholic, yet filled with a sense of dignity and resolve."
      }
    ],
    "temperature": 0.4,
    "min_p": 0.8,
    "max_tokens": 48,
    "stream": false
  }'

We can either have the last padding element be always exactly equal to max-num-batched-tokens or adjust for its "actual" padded value computed by _get_token_paddings (a power of 2 or "bucketed" power of 2).

This PR implements the latter because I ultimately think having these sizes be aligned can help with memory accesses here and there.
I am open to discuss better approaches if you have suggestions.

Signed-off-by: NickLucche <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added v1 tpu Related to Google TPUs labels Apr 16, 2025
# Bucketed padding with max_token_size not a power of two.
max_token_size = 317
expected_paddings = [16, 32, 64, 128, 192, 256, 320]
actual_paddings = _get_token_paddings(min_token_size, max_token_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should also add a test when padding_gap == 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did in a previous PR it's below. That is for exponential padding.

@NickLucche
Copy link
Contributor Author

Thanks for reviewing @yaochengji .

Signed-off-by: NickLucche <[email protected]>
Copy link
Collaborator

@yaochengji yaochengji left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the fix!

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) April 17, 2025 16:27
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 17, 2025
@robertgshaw2-redhat robertgshaw2-redhat merged commit 5989f46 into vllm-project:main Apr 17, 2025
59 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants