@@ -218,8 +218,9 @@ def load_model(self):
         print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
               f"decoding {num_decode_tokens} tokens with batch "
               f"size {batch_size}")
-        self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
-                                      restricted_tokens, batch_size)
+        with _maybe_warmup_context():
+            self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
+                                          restricted_tokens, batch_size)
         all_warmup_end_t = time.time()
         all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
         self.perf_metrics.log("total warmup time", all_warmup_total_t)
@@ -262,13 +263,11 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
               f"{prompt_len} and max output tokens {num_decode_tokens}.")
 
         print("[SpyreWorker] warmup 1/2...")
-        with _maybe_warmup_context():
-            # TODO: torch_sendnn.CleanGraph() should be necessary?
-            # warmup 1st forward pass
-            self._warmup_model_forward_pass(warmup_tokens_tensor,
-                                            valid_token_ids_tensor, prompt_len,
-                                            num_decode_tokens, batch_size,
-                                            extra_kwargs)
+        # warmup 1st forward pass
+        self._warmup_model_forward_pass(warmup_tokens_tensor,
+                                        valid_token_ids_tensor, prompt_len,
+                                        num_decode_tokens, batch_size,
+                                        extra_kwargs)
         self.perf_metrics.log("warmup 1 time",
                               time.time() - warmup_start_t,
                               batch_size=batch_size,
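
The net effect of the diff is that the warmup context is entered once by the caller in `load_model()` and covers the entire fixed-size warmup, rather than only the first forward pass inside `_warmup_spyre_fixed_size()`. Below is a minimal sketch of that caller-level pattern; `_maybe_warmup_context` is treated here as a context manager that is a no-op unless a backend-specific warmup mode is needed. The backend check, print statements, and simplified function signature are illustrative assumptions, not the actual vllm-spyre implementation.

```python
from contextlib import contextmanager, nullcontext
import time


@contextmanager
def _maybe_warmup_context(backend: str = "eager"):
    # Hypothetical stand-in: enter a backend-specific warmup mode only
    # when the compiled backend is in use; otherwise do nothing special.
    if backend == "sendnn":
        print("[sketch] entering warmup mode")
        try:
            yield
        finally:
            print("[sketch] leaving warmup mode")
    else:
        with nullcontext():
            yield


def _warmup_spyre_fixed_size(prompt_len, num_decode_tokens, batch_size):
    # Stand-in for the worker method: both warmup forward passes now run
    # inside the context entered by the caller, not one opened here.
    print(f"[sketch] warming up prompt_len={prompt_len}, "
          f"decode_tokens={num_decode_tokens}, batch_size={batch_size}")


# Caller-level usage mirroring the change in load_model():
start = time.time()
with _maybe_warmup_context():
    _warmup_spyre_fixed_size(prompt_len=64, num_decode_tokens=20, batch_size=1)
print(f"[sketch] total warmup time: {time.time() - start:.3f}s")
```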