1 file changed: +35 -0 lines changed
  MIN_NUM_SEQS = 8


+ #########################################################
+ # Ways to avoid recompilation
+ #########################################################
+ #
+ # The model executor has two primary components:
+ # 1. preparing the model and sampler inputs
+ # 2. executing the model and sampler.
+ # The core idea is to avoid any TPU computation during input preparation. For
+ # better compilation tracking and increased flexibility, the model execution and
+ # sampler are divided into several distinct components.
+ #
+ # Below are the detailed steps:
+ #
+ # Step 1
+ # It is recommended to avoid TPU operations when preparing the model and sampler
+ # inputs. CPU tensors can be prepared and transferred to the XLA device using
+ # cpu_tensor.to(xla_device), which only triggers CPU-to-TPU transfers and avoids
+ # compilation.
+ #
+ # Step 2
+ # The TPU execution should be decomposed into subgraphs (4 at the moment):
+ # 1. the main model
+ # 2. selecting hidden states for each request
+ # 3. sampler
+ # 4. encoder.
+ # Each subgraph should be decorated with torch.compile. This is used to make
+ # sure that we have the same subgraph topology in both dummy_run and
+ # execute_model. The results from these subgraphs should either be passed to
+ # other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
+ # subsequent processing on the CPU.
+ #
+ # Step 3
+ # The dummy_run should be comprehensive, ensuring all potential input shapes and
+ # branch predictions are included as subgraph inputs to facilitate
+ # pre-compilation.
  class TPUModelRunner:

      def __init__(
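A minimal sketch of Step 1, assuming a torch_xla installation; the tensor names and sizes are illustrative only, not taken from this change. All input preparation happens with plain CPU ops, and the only TPU interaction is the host-to-device copy, which does not trigger XLA compilation.

import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

# Build the inputs with ordinary CPU ops; nothing here touches the TPU.
num_tokens = 16
input_ids_cpu = torch.zeros(num_tokens, dtype=torch.int64)
positions_cpu = torch.arange(num_tokens, dtype=torch.int64)

# Host-to-device transfer only; no graph is traced, so no recompilation.
input_ids = input_ids_cpu.to(xla_device)
positions = positions_cpu.to(xla_device)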
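Step 2 can be pictured with the sketch below: one of the four subgraphs (here, selecting one hidden state per request) is wrapped in torch.compile so that dummy_run and execute_model trace the same topology. The module itself and the backend="openxla" choice are assumptions for illustration, not the code from this change.

import torch

class SelectHiddenStates(torch.nn.Module):
    # Subgraph 2: pick one hidden state per request (e.g. the last token
    # of each sequence).
    def forward(self, hidden_states: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
        return hidden_states[indices]

# Wrapping the subgraph keeps its traced topology identical wherever it is
# called, which makes compilations easier to track.
select_hidden_states = torch.compile(SelectHiddenStates(),
                                     backend="openxla",
                                     fullgraph=True,
                                     dynamic=False)

# Downstream, the result is either fed into the sampler subgraph or pulled
# back to the host with xla_tensor.cpu() for CPU-side post-processing.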
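Step 3 amounts to warming the compilation cache. The sketch below assumes a small set of bucketed token counts and a runner object exposing the compiled subgraphs; the bucket sizes and the runner.model / runner.select_hidden_states attributes are hypothetical. A complete dummy_run would also exercise every branch condition, not just every padded shape.

import torch
import torch_xla.core.xla_model as xm

def dummy_run(runner, token_buckets=(16, 32, 64, 128)):
    # Run every padded input shape once so each (subgraph, shape) pair is
    # compiled and cached before real requests arrive.
    xla_device = xm.xla_device()
    for num_tokens in token_buckets:
        input_ids = torch.zeros(num_tokens, dtype=torch.int64).to(xla_device)
        positions = torch.zeros(num_tokens, dtype=torch.int64).to(xla_device)
        # runner.model and runner.select_hidden_states are hypothetical
        # handles to the torch.compile-wrapped subgraphs.
        hidden_states = runner.model(input_ids, positions)
        indices = torch.zeros(1, dtype=torch.int64).to(xla_device)
        runner.select_hidden_states(hidden_states, indices)
        xm.mark_step()  # flush the pending graph so it is compiled now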