
Commit c948bc5

yaochengji authored and Mu Huai committed
[DOC][TPU] Add core idea about avoiding recompilation after warmup (vllm-project#16614)
Signed-off-by: Chengji Yao <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 70f29d3 commit c948bc5

File tree

1 file changed: +35 -0 lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 35 additions & 0 deletions
@@ -53,6 +53,41 @@
 MIN_NUM_SEQS = 8


+#########################################################
+# Ways to avoid recompilation
+#########################################################
+#
+# The model executor has two primary components:
+# 1. preparing the model and sampler inputs
+# 2. executing the model and sampler.
+# The core idea is to avoid any TPU computation during input preparation. For
+# better compilation tracking and increased flexibility, the model execution and
+# sampler are divided into several distinct components.
+#
+# Below are the detailed steps:
+#
+# Step 1
+# It is recommended to avoid TPU operations when preparing the model and sampler
+# inputs. CPU tensors can be prepared and transferred to the XLA device using
+# cpu_tensor.to(xla_device), which only triggers CPU-to-TPU transfers and avoids
+# compilation.
+#
+# Step 2
+# The TPU execution should be decomposed into subgraphs (4 at the moment):
+# 1. the main model
+# 2. selecting hidden states for each request
+# 3. sampler
+# 4. encoder.
+# Each subgraph should be decorated with torch.compile. This is used to make
+# sure that we have the same subgraph topology in both dummy_run and
+# execute_model. The results from these subgraphs should either be passed to
+# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
+# subsequent processing on the CPU.
+#
+# Step 3
+# The dummy_run should be comprehensive, ensuring all potential input shapes and
+# branch predictions are included as subgraph inputs to facilitate
+# pre-compilation.
 class TPUModelRunner:

     def __init__(
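The shape discipline behind Step 1 is that host-side (CPU) code pads every variable-length input to one of a small, fixed set of bucket sizes before any device transfer, so the compiled TPU graphs only ever see a handful of shapes. Below is a minimal stdlib-only sketch of that padding idea; `pick_bucket` and `pad_to_bucket` are illustrative names, not vLLM's actual API, and only the `MIN_NUM_SEQS = 8` constant comes from the diff above:

```python
# Sketch of host-side shape bucketing: pad inputs to a small fixed set of
# sizes so a shape-specializing compiler (such as XLA) never encounters a
# new shape after warmup. Helper names are illustrative, not vLLM's API.
MIN_NUM_SEQS = 8  # constant from the snippet above


def pick_bucket(n: int, min_size: int = MIN_NUM_SEQS) -> int:
    """Round n up to the next power of two, with a floor of min_size."""
    bucket = min_size
    while bucket < n:
        bucket *= 2
    return bucket


def pad_to_bucket(values: list[int], pad_value: int = 0) -> list[int]:
    """Pad a host-side input list so its length is a bucket size."""
    bucket = pick_bucket(len(values))
    return values + [pad_value] * (bucket - len(values))


print(pick_bucket(3))   # -> 8: small batches share the smallest bucket
print(pick_bucket(13))  # -> 16: rounded up to the next power of two
```

Because padding happens on the CPU, the subsequent `cpu_tensor.to(xla_device)` transfer ships an already-bucketed shape and never forces the device to trace a graph for an unseen size.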

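Steps 2 and 3 can be illustrated with a toy stand-in for `torch.compile`: a cache keyed on input shape, where every cache miss represents a recompilation. A comprehensive `dummy_run` visits every bucket shape once, so real execution only ever hits the cache. The decorator, subgraph stubs, and bucket list below are all hypothetical, chosen to mirror the structure the comment describes, not vLLM's real partitioning:

```python
import functools

# Toy stand-in for a shape-specializing compiler: each unseen
# (function, shape) pair triggers a "compilation" (cache miss).
compilations: list[str] = []


def toy_compile(fn):
    cache: set = set()

    @functools.wraps(fn)
    def wrapper(shape):
        key = (fn.__name__, shape)
        if key not in cache:           # unseen shape: "recompile"
            compilations.append(f"{fn.__name__}@{shape}")
            cache.add(key)
        return fn(shape)

    return wrapper


# Three of the subgraphs from Step 2, reduced to shape-passing stubs.
@toy_compile
def main_model(shape):
    return shape

@toy_compile
def select_hidden_states(shape):
    return shape

@toy_compile
def sampler(shape):
    return shape


BUCKETS = [8, 16, 32]  # illustrative padded batch sizes


def dummy_run():
    """Step 3: warm every subgraph at every bucket shape."""
    for b in BUCKETS:
        sampler(select_hidden_states(main_model(b)))


def execute_model(num_reqs: int):
    """Pad to a bucket so execution reuses a precompiled graph."""
    bucket = min(b for b in BUCKETS if b >= num_reqs)
    sampler(select_hidden_states(main_model(bucket)))


dummy_run()
warmup_count = len(compilations)  # 3 subgraphs x 3 buckets = 9
for n in (1, 5, 12, 30):
    execute_model(n)
assert len(compilations) == warmup_count  # no recompilation after warmup
```

Keeping the same subgraph topology in `dummy_run` and `execute_model` is what makes this guarantee hold: if execution ever composed the subgraphs differently, or fed them an unbucketed shape, the final assertion (a new compilation after warmup) would fire.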
0 commit comments
