@@ -75,6 +75,8 @@

LORA_WARMUP_RANK = 8

+DUMMY_TOKEN_ID = -1
+

class PhaseType(Enum):
    PREFILL = 'prefill'
@@ -701,6 +703,9 @@ def __init__(

        # For multi-step scheduling
        self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []

    def _set_gc_threshold(self) -> None:
        # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -804,6 +809,12 @@ def load_model(self) -> None:
        msg = f"Loading model weights took in total {m.get_summary_string()}"
        logger.info(msg)

+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
    def get_model(self) -> nn.Module:
        return self.model
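Note: the helper above wraps the adapter in an HPU graph only when lazy mode is active (`htorch.utils.internal.is_lazy()`); otherwise the plain `HpuModelAdapter` is returned. A minimal, framework-agnostic sketch of the same "wrap only when graph mode is available" pattern; `is_lazy_mode` and `GraphWrapper` are illustrative stand-ins, not HPU APIs:

```python
# Sketch only: mirrors the conditional graph-wrapping idea, not Habana's API.
import os
from typing import Any


def is_lazy_mode() -> bool:
    # Stand-in for htorch.utils.internal.is_lazy().
    return os.environ.get("PT_HPU_LAZY_MODE", "1") == "1"


class GraphWrapper:
    """Stand-in for a graph capture wrapper: records a callable and replays it."""

    def __init__(self, module: Any):
        self.module = module

    def __call__(self, *args, **kwargs):
        # A real graph wrapper would capture the first call and replay it later.
        return self.module(*args, **kwargs)


def maybe_wrap(module: Any) -> Any:
    # Same shape as _maybe_wrap_in_hpu_graph: wrap only if lazy mode is on.
    return GraphWrapper(module) if is_lazy_mode() else module
```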
@@ -2115,6 +2126,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

        return lora_mask, lora_logits_mask

+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
    @torch.inference_mode()
    def execute_model(
        self,
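Note: `_pad_to_max_num_seqs` pads the first dimension of a tensor up to the runner's `max_num_seqs` with a fill value (here the `DUMMY_TOKEN_ID` sentinel), so cached sampled tokens always keep a fixed batch shape across steps. A standalone sketch of the same padding logic with an assumed `MAX_NUM_SEQS=4`:

```python
import torch

DUMMY_TOKEN_ID = -1
MAX_NUM_SEQS = 4  # assumed value for illustration


def pad_to_max_num_seqs(tensor: torch.Tensor, value: int) -> torch.Tensor:
    # Pad dim 0 up to MAX_NUM_SEQS with `value`, keeping trailing dims intact.
    padding_needed = MAX_NUM_SEQS - tensor.size(0)
    if padding_needed > 0:
        padding = torch.full((padding_needed, *tensor.shape[1:]),
                             value,
                             device=tensor.device,
                             dtype=tensor.dtype)
        tensor = torch.cat([tensor, padding])
    return tensor


sampled = torch.tensor([[101], [202]])   # 2 live sequences, 1 token each
print(pad_to_max_num_seqs(sampled, DUMMY_TOKEN_ID).tolist())
# [[101], [202], [-1], [-1]]
```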
@@ -2125,6 +2151,37 @@ def execute_model(
        warmup_mode=False,
        seqs=None,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+            htorch.core.mark_step()
+
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                # not first or last multi-step
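Note: on the driver, a decode step first scatters the previous step's cached sampled tokens into the current `input_tokens`, remapping rows by sequence id because the batch composition may have changed between steps (finished or newly scheduled sequences). A portable sketch of that remapping with made-up sequence ids; unlike the HPU code above, it drops the `-1` "no longer present" targets before the copy so it stays valid on any backend:

```python
import torch

# Hypothetical previous-step and current-step batches.
prev_seq_ids = [11, 12, 13, 14]                     # row order of cached outputs
cached_tokens = torch.tensor([[5], [6], [7], [8]])  # tokens sampled last step

cur_seq_ids = [13, 11, 20]                          # 12/14 finished, 20 is new
input_tokens = torch.zeros((3, 1), dtype=torch.long)

# Map each cached row to its position in the current batch (-1 = dropped).
cur_pos = {sid: i for i, sid in enumerate(cur_seq_ids)}
targets = [cur_pos.get(sid, -1) for sid in prev_seq_ids]   # [1, -1, 0, -1]

# Keep only rows that still exist, then scatter them into input_tokens.
keep = [i for i, t in enumerate(targets) if t >= 0]
index = torch.tensor([targets[i] for i in keep])
input_tokens.index_copy_(0, index, cached_tokens[keep])
print(input_tokens.tolist())   # [[7], [5], [0]] -- new seq 20's row is untouched
```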
@@ -2140,7 +2197,21 @@ def execute_model(
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
+        # Rank != 0 workers have is_prompt == None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                model_input.input_tokens.size(1) == 1:
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+        else:
+            input_tokens = model_input.input_tokens
        input_positions = model_input.input_positions
        attn_metadata = model_input.attn_metadata
        sampling_metadata = model_input.sampling_metadata
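Note: with delayed sampling, the decode-step token ids are only materialized on the driver (they come from the previous step's deferred sample), so the driver broadcasts the freshly patched `input_tokens` to the other ranks instead of each rank relying on its stale copy. A rough, self-contained sketch of what a `broadcast_tensor_dict`-style helper does; this is not vLLM's implementation and it runs single-process with the gloo backend:

```python
# Sketch: driver (src) sends tensor metadata + tensors, other ranks receive them.
import os
import torch
import torch.distributed as dist


def broadcast_tensor_dict(tensor_dict=None, src=0):
    rank = dist.get_rank()
    # Broadcast keys/shapes/dtypes first so receivers can allocate buffers.
    obj = [[(k, v.shape, v.dtype) for k, v in tensor_dict.items()]
           ] if rank == src else [None]
    dist.broadcast_object_list(obj, src=src)
    out = {}
    for k, shape, dtype in obj[0]:
        t = tensor_dict[k] if rank == src else torch.empty(shape, dtype=dtype)
        dist.broadcast(t, src=src)
        out[k] = t
    return out


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    # Driver side: ship the freshly updated decode input tokens to workers.
    got = broadcast_tensor_dict({"input_tokens": torch.tensor([[3], [7]])})
    print(got["input_tokens"].tolist())
    dist.destroy_process_group()
```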
@@ -2187,7 +2258,7 @@ def execute_model(
                                f"graphs{'T' if use_graphs else 'F'}")
        else:
            model_event_name = 'model_executable'
-        if num_steps > 1:
+        if num_steps > 1 or use_delayed_sampling:
            # in case of multi-step scheduling
            # we only want to pythonize in the last step
            sampling_metadata.skip_sampler_cpu_output = True
@@ -2247,9 +2318,9 @@ def try_revert_dummy_output_tokens():
            if not self.is_driver_worker:
                continue

-            if model_input.async_callback is not None:
-                model_input.async_callback()
-            # Sample the next token.
+            if use_delayed_sampling:
+                fake_output = self._delayed_sampler_outputs(model_input)
+
            with self.profiler.record_event(
                    'internal', ('sample_'
                                 f'{"prompt" if is_prompt else "decode"}_'
@@ -2261,9 +2332,16 @@ def try_revert_dummy_output_tokens():
                )
                if num_steps > 1:
                    output = output.sampled_token_ids
-                    self.cached_step_outputs.append(
-                        output.detach().clone())
+                    self.cached_step_outputs.append(output)
+                if use_delayed_sampling and self.is_driver_worker:
+                    self._patch_prev_output()
+                    output = self._pad_to_max_num_seqs(
+                        output.sampled_token_ids, DUMMY_TOKEN_ID)
+                    self.cached_step_outputs.append(output)
+                    self.cached_step_inputs.append(model_input)
            htorch.core.mark_step()
+            if model_input.async_callback is not None:
+                model_input.async_callback()
            if i < num_steps - 1:
                if i == 0:
                    if model_input.async_callback is not None:
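Note: the core of delayed sampling is visible here. Each decode step first patches the placeholder returned by the previous step with the now-materialized token ids (`_patch_prev_output`), then caches this step's still device-resident sampled tokens together with its model input, and only afterwards runs the async callback, so the engine is never blocked on a device-to-host sync. A toy, host-only simulation of that "return a placeholder now, fix it up next step" discipline; all names here are illustrative, not vLLM APIs:

```python
from collections import deque

DUMMY_TOKEN_ID = -1

cached_step_outputs: deque = deque()   # tokens "sampled on device", not yet synced
returned_placeholders: deque = deque() # placeholder outputs already handed back


def patch_prev_output() -> None:
    # Replace the last placeholder we returned with the real token ids the
    # previous step produced (in the runner this is where the .cpu() sync is).
    if not cached_step_outputs or not returned_placeholders:
        return
    real = cached_step_outputs.popleft()
    returned_placeholders.popleft()[:] = real


def decode_step(step: int, num_seqs: int) -> list:
    patch_prev_output()
    real_tokens = [1000 + step * 10 + i for i in range(num_seqs)]  # pretend sampling
    cached_step_outputs.append(real_tokens)
    placeholder = [DUMMY_TOKEN_ID] * num_seqs   # what the caller sees this step
    returned_placeholders.append(placeholder)
    return placeholder


outs = [decode_step(s, num_seqs=2) for s in range(3)]
patch_prev_output()   # flush the final step
print(outs)           # placeholders were patched in place with the real tokens
```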
@@ -2336,11 +2414,30 @@ def try_revert_dummy_output_tokens():
                    is_prompt=is_prompt)
            self.profiler.record_counter(self.event_start, counters)
        if num_steps == 1:
+            if self.return_hidden_states:
+                # we only need to pass hidden states of most recent token
+                assert model_input.sampling_metadata is not None
+                if model_input.is_prompt:
+                    output.prefill_hidden_states = hidden_states
+                output.hidden_states = hidden_states
+            if use_delayed_sampling:
+                if self.is_driver_worker:
+                    return [fake_output]
+                else:
+                    return []
+
            return [output] if self.is_driver_worker else []
        else:
            return []
        return output if type(output) is list else [output]

+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
    def _decode_sampler_outputs(self, model_input):
        use_async_out_proc = model_input.async_callback is not None
        sampler_outputs = []
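Note: `_delayed_sampler_outputs` builds a decode-shaped output with one `DUMMY_TOKEN_ID` per sequence group; `_patch_prev_output` later overwrites each placeholder's `samples[0].output_token` with the real id. A tiny illustrative sketch of that placeholder shape, using plain dicts as stand-ins for vLLM's sampler output objects:

```python
# Illustrative only: the shape of the placeholder handed back to the engine
# while the real tokens are still on the device.
DUMMY_TOKEN_ID = -1


def delayed_sampler_outputs(seq_group_ids):
    # One dummy token per sequence group; it gets patched on the next step.
    return [{"seq_group": sid, "samples": [{"output_token": DUMMY_TOKEN_ID}]}
            for sid in seq_group_ids]


print(delayed_sampler_outputs([0, 1, 2]))
```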
@@ -2407,3 +2504,32 @@ def shutdown_inc(self):

    def __del__(self):
        self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # If there's no output to patch, just return. This usually happens
+        # when a new request starts after all previous requests have completed.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack: assigning to output_token_ids would trigger a
+            # cache recomputation, and we only need to update the last token.
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
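Note: the patch writes straight into `output_token_ids_array` and `_cached_all_token_ids` instead of going through the public setter, because only the last element needs fixing. A toy illustration of why the direct write is the cheap path; `ToySeqData` is a stand-in, not vLLM's `SequenceData`:

```python
# Toy model of the trade-off: the setter rebuilds the cached token list,
# the direct write just fixes the last element in both views.
from array import array


class ToySeqData:
    def __init__(self, prompt_ids, output_ids):
        self.prompt_token_ids = list(prompt_ids)
        self.output_token_ids_array = array("l", output_ids)
        self._cached_all_token_ids = list(prompt_ids) + list(output_ids)
        self.recomputations = 0

    def set_output_token_ids(self, new_ids):
        # The "expensive" path: replaces everything and rebuilds the cache.
        self.output_token_ids_array = array("l", new_ids)
        self._cached_all_token_ids = self.prompt_token_ids + list(new_ids)
        self.recomputations += 1


seq = ToySeqData([1, 2, 3], [10, -1])   # -1 is the dummy token to patch
# Cheap patch, as in _patch_prev_output: fix only the last token in both views.
seq.output_token_ids_array[-1] = 42
seq._cached_all_token_ids[-1] = 42
print(list(seq.output_token_ids_array), seq._cached_all_token_ids,
      seq.recomputations)   # [10, 42] [1, 2, 3, 10, 42] 0
```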