[TPU][V1] Fix padding recompilation when max-num-batched-tokens is not even #16726

Merged
8 changes: 8 additions & 0 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -294,11 +294,19 @@ def test_update_states_request_unscheduled(model_runner):


 def test_get_paddings():
+    # Bucketed padding
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
+
+    # Bucketed padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,

Collaborator: I think we should also add a test for when padding_gap == 0?

Contributor (author): I did that in a previous PR; it's below. That one covers exponential padding.

+                                          padding_gap)
     assert actual_paddings == expected_paddings
+
     # Exponential padding.
     max_token_size, padding_gap = 1024, 0
     expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
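
For context on the review discussion above, the following is a minimal sketch of a padding generator that reproduces the values asserted in `test_get_paddings`. It is an illustration only; `token_paddings` is a hypothetical stand-in, and the real `_get_token_paddings` in `vllm/v1/worker/tpu_model_runner.py` may differ in details.

```python
def token_paddings(min_token_size: int, max_token_size: int,
                   padding_gap: int) -> list[int]:
    """Hypothetical stand-in for _get_token_paddings, chosen to match the
    expected values asserted in test_get_paddings above."""
    paddings = []
    num = min_token_size
    if padding_gap == 0:
        # Exponential padding: double until max_token_size is covered.
        while True:
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        # Bucketed padding: double up to padding_gap, then grow in steps of
        # padding_gap until max_token_size is covered.
        while num <= padding_gap:
            paddings.append(num)
            num *= 2
        num = padding_gap
        while num < max_token_size:
            num += padding_gap
            paddings.append(num)
    return paddings


assert token_paddings(16, 512, 64) == [
    16, 32, 64, 128, 192, 256, 320, 384, 448, 512
]
assert token_paddings(16, 317, 64) == [16, 32, 64, 128, 192, 256, 320]
assert token_paddings(16, 1024, 0) == [16, 32, 64, 128, 256, 512, 1024]
```

The property this PR relies on is that the last bucket always covers `max_token_size`: for `max_token_size=317` and `padding_gap=64` the buckets end at 320, not 317.
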
12 changes: 7 additions & 5 deletions vllm/v1/worker/tpu_model_runner.py
@@ -128,10 +128,16 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+        self.num_tokens_paddings = _get_token_paddings(
+            min_token_size=16,
+            max_token_size=scheduler_config.max_num_batched_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
+        # padded max value to pre-allocate data structures and pre-compile.
+        self.max_num_tokens = self.num_tokens_paddings[-1]

         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -211,10 +217,6 @@ def __init__(
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
-        self.num_tokens_paddings = _get_token_paddings(
-            min_token_size=16,
-            max_token_size=self.max_num_tokens,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

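
A worked example of what the two hunks above change, with hypothetical numbers mirroring the new test case (317 tokens, gap 64): the runner now sizes pre-allocated buffers and pre-compiled shapes by the largest padding bucket instead of the raw scheduler value.

```python
# Hypothetical numbers mirroring the new test case; not taken from a real config.
max_num_batched_tokens = 317                              # scheduler_config value
num_tokens_paddings = [16, 32, 64, 128, 192, 256, 320]    # padding_gap = 64

# Before: max_num_tokens was the raw config value (317), while runtime batches
# are padded up to a bucket (320), so pre-compiled shapes and padded shapes
# could disagree, which is the recompilation this PR fixes.
old_max_num_tokens = max_num_batched_tokens               # 317

# After: use the largest bucket, so buffers like arange_np and the
# pre-compiled graphs cover every shape the runner will actually see.
new_max_num_tokens = num_tokens_paddings[-1]              # 320
assert new_max_num_tokens >= max_num_batched_tokens
```
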
4 changes: 2 additions & 2 deletions vllm/v1/worker/tpu_worker.py
@@ -156,8 +156,8 @@ def determine_available_memory(self) -> int:
             self.vllm_config.compilation_config.static_forward_context,
             runner_kv_caches)

-        self.model_runner._dummy_run(
-            self.scheduler_config.max_num_batched_tokens)
+        # `max_num_tokens >= max_num_batched_tokens` due to padding.
+        self.model_runner._dummy_run(self.model_runner.max_num_tokens)

         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
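
The comment in the hunk above states the invariant that makes this change safe: after padding, `max_num_tokens >= max_num_batched_tokens`. The sketch below only illustrates the profiling intent with hypothetical names; the actual worker simply calls `self.model_runner._dummy_run(self.model_runner.max_num_tokens)` as shown in the diff.

```python
from typing import Callable


def profile_peak_memory(dummy_run: Callable[[int], None],
                        bytes_in_use: Callable[[], int],
                        total_bytes: int,
                        num_tokens_paddings: list[int]) -> int:
    """Illustrative only: warm up at the largest padding bucket, i.e. the
    worst-case shape the runner will actually compile, then report how much
    device memory is left over (e.g. for the KV cache)."""
    dummy_run(num_tokens_paddings[-1])
    return total_bytes - bytes_in_use()
```
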