
Commit e0b6707

NickLucche authored and Mu Huai committed
[TPU][V1] Fix padding recompilation when max-num-batched-tokens is not even (vllm-project#16726)
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent: 0a44a1f · commit: e0b6707


3 files changed: +17 −7 lines changed


tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 8 additions & 0 deletions
@@ -294,11 +294,19 @@ def test_update_states_request_unscheduled(model_runner):
 
 
 def test_get_paddings():
+    # Bucketed padding
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+
+    # Bucketed padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
     assert actual_paddings == expected_paddings
+
     # Exponential padding.
     max_token_size, padding_gap = 1024, 0
     expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
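The three cases above pin down the bucketing scheme: with a non-zero gap, bucket sizes double from `min_token_size` up to `padding_gap`, then grow linearly in `padding_gap` steps until they cover `max_token_size`; with a gap of 0 they double all the way. Below is a minimal sketch of `_get_token_paddings` consistent with those expectations — an illustration, not necessarily the exact vLLM implementation, which may differ in details such as logging and input validation:

```python
def _get_token_paddings(min_token_size: int, max_token_size: int,
                        padding_gap: int) -> list[int]:
    """Sketch: padding bucket sizes from min_token_size up to the first
    value that covers max_token_size (consistent with the test above)."""
    paddings = []
    num = min_token_size
    if padding_gap == 0:
        # Exponential buckets: double until max_token_size is covered.
        while True:
            paddings.append(num)
            if num >= max_token_size:
                break
            num *= 2
    else:
        # Double up to the gap, then grow linearly in steps of padding_gap.
        while num <= padding_gap:
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:
            num += padding_gap
            paddings.append(num)
    return paddings
```

For example, `_get_token_paddings(16, 317, 64)` yields `[16, 32, 64, 128, 192, 256, 320]`: 320 is the smallest bucket that covers 317, and it is exactly the value the runner must size its buffers for.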

vllm/v1/worker/tpu_model_runner.py

Lines changed: 7 additions & 5 deletions
@@ -128,10 +128,16 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
+        self.num_tokens_paddings = _get_token_paddings(
+            min_token_size=16,
+            max_token_size=scheduler_config.max_num_batched_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+        # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
+        # padded max value to pre-allocate data structures and pre-compile.
+        self.max_num_tokens = self.num_tokens_paddings[-1]
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -211,10 +217,6 @@ def __init__(
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
-        self.num_tokens_paddings = _get_token_paddings(
-            min_token_size=16,
-            max_token_size=self.max_num_tokens,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

vllm/v1/worker/tpu_worker.py

Lines changed: 2 additions & 2 deletions
@@ -156,8 +156,8 @@ def determine_available_memory(self) -> int:
             self.vllm_config.compilation_config.static_forward_context,
             runner_kv_caches)
 
-        self.model_runner._dummy_run(
-            self.scheduler_config.max_num_batched_tokens)
+        # `max_num_tokens >= max_num_batched_tokens` due to padding.
+        self.model_runner._dummy_run(self.model_runner.max_num_tokens)
 
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
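Memory profiling follows the same logic: the dummy run should exercise the largest shape the runner will ever execute, which is the top padding bucket rather than the raw `max_num_batched_tokens`. Continuing the hypothetical numbers above (and again assuming the sketch of `_get_token_paddings`):

```python
# Hypothetical numbers: 317-token scheduler limit, 64-token padding gap.
max_num_batched_tokens = 317
max_num_tokens = _get_token_paddings(16, max_num_batched_tokens, 64)[-1]

# The invariant the new comment relies on:
assert max_num_tokens >= max_num_batched_tokens  # 320 >= 317
# so profiling with max_num_tokens covers the worst-case padded batch.
```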
