Skip to content

Commit bf1a1a3

Browse files
committed
apply to v0: fixed issue with warmup_context not capturing full generate
Signed-off-by: Yannick Schnider <[email protected]>
1 parent e5a582a commit bf1a1a3

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

vllm_spyre/worker/spyre_worker.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,9 @@ def load_model(self):
218218
print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
219219
f"decoding {num_decode_tokens} tokens with batch "
220220
f"size {batch_size}")
221-
self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
222-
restricted_tokens, batch_size)
221+
with _maybe_warmup_context():
222+
self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
223+
restricted_tokens, batch_size)
223224
all_warmup_end_t = time.time()
224225
all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
225226
self.perf_metrics.log("total warmup time", all_warmup_total_t)
@@ -262,13 +263,11 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
262263
f"{prompt_len} and max output tokens {num_decode_tokens}.")
263264

264265
print("[SpyreWorker] warmup 1/2...")
265-
with _maybe_warmup_context():
266-
# TODO: torch_sendnn.CleanGraph() should be necessary?
267-
# warmup 1st forward pass
268-
self._warmup_model_forward_pass(warmup_tokens_tensor,
269-
valid_token_ids_tensor, prompt_len,
270-
num_decode_tokens, batch_size,
271-
extra_kwargs)
266+
# warmup 1st forward pass
267+
self._warmup_model_forward_pass(warmup_tokens_tensor,
268+
valid_token_ids_tensor, prompt_len,
269+
num_decode_tokens, batch_size,
270+
extra_kwargs)
272271
self.perf_metrics.log("warmup 1 time",
273272
time.time() - warmup_start_t,
274273
batch_size=batch_size,

0 commit comments

Comments (0)