LORA_WARMUP_RANK = 8

+DUMMY_TOKEN_ID = -1
+

class Singleton(type):
    _instances: Dict[type, object] = {}
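The new `DUMMY_TOKEN_ID = -1` constant is used later in this diff as a sentinel: it pads the sampled-token tensor out to `max_num_seqs` and fills the placeholder outputs that get patched with real token ids on the next step. A minimal, self-contained illustration (plain PyTorch, values made up):

```python
import torch

DUMMY_TOKEN_ID = -1

# Hypothetical batch of sampled token ids, padded to a fixed batch size.
sampled = torch.tensor([101, 202, DUMMY_TOKEN_ID, DUMMY_TOKEN_ID])

# Real tokens are recovered by masking the sentinel out.
print(sampled[sampled != DUMMY_TOKEN_ID].tolist())  # [101, 202]
```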
@@ -668,6 +670,9 @@ def __init__(
        # For multi-step scheduling
        self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []

    def _set_gc_threshold(self) -> None:
        # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -771,6 +776,12 @@ def load_model(self) -> None:
        msg = f"Loading model weights took in total {m.get_summary_string()}"
        logger.info(msg)

+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
    def get_model(self) -> nn.Module:
        return self.model
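`_maybe_wrap_in_hpu_graph` only wraps the adapter for HPU graph capture when the lazy backend is active; eager mode gets the plain `HpuModelAdapter`. A rough sketch of the same dispatch pattern that runs without Gaudi (the inner wrapper is an identity stand-in for `htorch.hpu.wrap_in_hpu_graph`):

```python
import torch.nn as nn

def wrap_for_graph_capture(module: nn.Module, is_lazy: bool) -> nn.Module:
    # Identity stand-in for htorch.hpu.wrap_in_hpu_graph(...,
    # disable_tensor_cache=True); only applied when lazy mode is on.
    def fake_wrap_in_hpu_graph(m: nn.Module) -> nn.Module:
        return m

    return fake_wrap_in_hpu_graph(module) if is_lazy else module

print(wrap_for_graph_capture(nn.Linear(2, 2), is_lazy=True))
```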
@@ -2020,6 +2031,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
        return lora_mask, lora_logits_mask

+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
    @torch.inference_mode()
    def execute_model(
        self,
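A quick standalone demonstration of what `_pad_to_max_num_seqs` does, with `max_num_seqs` passed in explicitly instead of read from `self` (batch size and values are made up):

```python
import torch

def pad_to_max_num_seqs(tensor: torch.Tensor, value: int,
                        max_num_seqs: int) -> torch.Tensor:
    # Same logic as the helper above, but max_num_seqs is an argument.
    padding_needed = max_num_seqs - tensor.size(0)
    if padding_needed:
        padding = torch.full((padding_needed, *tensor.shape[1:]),
                             value,
                             device=tensor.device,
                             dtype=tensor.dtype)
        tensor = torch.cat([tensor, padding])
    return tensor

sampled = torch.tensor([[11], [22]])          # 2 live sequences, 1 token each
padded = pad_to_max_num_seqs(sampled, -1, 4)  # pad the batch dim up to 4 slots
print(padded.squeeze(-1).tolist())            # [11, 22, -1, -1]
```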
@@ -2030,6 +2056,37 @@ def execute_model(
        warmup_mode=False,
        seqs=None,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+                htorch.core.mark_step()
+
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                # not first or last multi-step
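This block replays the previous step's cached, still on-device sampled tokens into the current decode batch: sequence ids are matched between the cached input and the current one, and `index_copy_` scatters each cached token into the row its sequence now occupies. A simplified, runnable sketch of the same remapping; unlike the diff, which maps vanished sequences to index -1, this version simply drops them:

```python
import torch

# Previous step: one token was sampled for each of sequence ids [7, 9, 12].
prev_seq_ids = [7, 9, 12]
prev_tokens = torch.tensor([[501], [502], [503]])

# Current step: the batch got reordered and sequence 9 has finished.
cur_seq_ids = [12, 7, 15]
input_tokens = torch.zeros((3, 1), dtype=prev_tokens.dtype)

cur_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids)}
keep = [i for i, sid in enumerate(prev_seq_ids) if sid in cur_pos]
target = torch.tensor([cur_pos[prev_seq_ids[i]] for i in keep])

# Scatter surviving tokens into the rows their sequences now occupy.
input_tokens.index_copy_(0, target, prev_tokens[keep])
print(input_tokens.squeeze(-1).tolist())  # [503, 501, 0]
```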
@@ -2045,7 +2102,21 @@ def execute_model(
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
+        # Rank != 0 workers have is_prompt == None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                model_input.input_tokens.size(1) == 1:
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+        else:
+            input_tokens = model_input.input_tokens
        input_positions = model_input.input_positions
        attn_metadata = model_input.attn_metadata
        sampling_metadata = model_input.sampling_metadata
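Because the driver has just rewritten `input_tokens` in place, the other tensor-parallel ranks cannot derive them from their own `model_input` (their `is_prompt` is `None`), so the driver re-broadcasts the tensor. A single-process sketch of that hand-off; `fake_broadcast` is only a stand-in for vLLM's `broadcast_tensor_dict`:

```python
from typing import Dict, Optional

import torch

_CHANNEL: Dict[str, torch.Tensor] = {}  # stand-in for the collective

def fake_broadcast(data: Optional[Dict[str, torch.Tensor]] = None,
                   src: int = 0) -> Dict[str, torch.Tensor]:
    # Driver (src) publishes; every other rank reads the same dict back.
    if data is not None:
        _CHANNEL.update(data)
        return data
    return dict(_CHANNEL)

def exchange_input_tokens(is_driver: bool,
                          tokens: Optional[torch.Tensor]) -> torch.Tensor:
    if is_driver:
        fake_broadcast({"input_tokens": tokens}, src=0)
        return tokens
    return fake_broadcast(src=0)["input_tokens"]

driver_view = exchange_input_tokens(True, torch.tensor([[7], [8]]))
worker_view = exchange_input_tokens(False, None)
print(worker_view.squeeze(-1).tolist())  # [7, 8]
```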
@@ -2092,7 +2163,7 @@ def execute_model(
                                f"graphs{'T' if use_graphs else 'F'}")
        else:
            model_event_name = 'model_executable'
-        if num_steps > 1:
+        if num_steps > 1 or use_delayed_sampling:
            # in case of multi-step scheduling
            # we only want to pythonize in the last step
            sampling_metadata.skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ def try_revert_dummy_output_tokens():
                if not self.is_driver_worker:
                    continue

-                if model_input.async_callback is not None:
-                    model_input.async_callback()
-                # Sample the next token.
+                if use_delayed_sampling:
+                    fake_output = self._delayed_sampler_outputs(model_input)
+
                with self.profiler.record_event(
                        'internal', ('sample_'
                                     f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ def try_revert_dummy_output_tokens():
                    )
                if num_steps > 1:
                    output = output.sampled_token_ids
-                    self.cached_step_outputs.append(
-                        output.detach().clone())
+                    self.cached_step_outputs.append(output)
+                if use_delayed_sampling and self.is_driver_worker:
+                    self._patch_prev_output()
+                    output = self._pad_to_max_num_seqs(
+                        output.sampled_token_ids, DUMMY_TOKEN_ID)
+                    self.cached_step_outputs.append(output)
+                    self.cached_step_inputs.append(model_input)
                htorch.core.mark_step()
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
                if i < num_steps - 1:
                    if i == 0:
                        if model_input.async_callback is not None:
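Taken together, the decode path now runs one step ahead of output processing: each step caches its unpythonized sampled tokens plus the `model_input` that produced them, returns a placeholder, and patches the real ids into the previously queued output on the next call. A toy, scheduler-free sketch of that one-step-delay queue:

```python
from collections import deque

DUMMY_TOKEN_ID = -1

cached_step_inputs: deque = deque()   # which step each cached output belongs to
cached_step_outputs: deque = deque()  # real sampled tokens, not yet surfaced

def decode_step(step_id: int, real_token: int) -> int:
    # Patch the previous step's placeholder with the token it really sampled.
    if cached_step_outputs:
        prev_step = cached_step_inputs.popleft()
        prev_token = cached_step_outputs.popleft()
        print(f"patching output of step {prev_step} with token {prev_token}")
    # Cache this step's real result and hand back a placeholder for now.
    cached_step_inputs.append(step_id)
    cached_step_outputs.append(real_token)
    return DUMMY_TOKEN_ID

for step, token in enumerate([11, 22, 33]):
    decode_step(step, token)
```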
@@ -2241,11 +2319,30 @@ def try_revert_dummy_output_tokens():
                    is_prompt=is_prompt)
                self.profiler.record_counter(self.event_start, counters)
            if num_steps == 1:
+                if self.return_hidden_states:
+                    # we only need to pass hidden states of most recent token
+                    assert model_input.sampling_metadata is not None
+                    if model_input.is_prompt:
+                        output.prefill_hidden_states = hidden_states
+                    output.hidden_states = hidden_states
+                if use_delayed_sampling:
+                    if self.is_driver_worker:
+                        return [fake_output]
+                    else:
+                        return []
+
                return [output] if self.is_driver_worker else []
            else:
                return []
        return output if type(output) is list else [output]

+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
    def _decode_sampler_outputs(self, model_input):
        use_async_out_proc = model_input.async_callback is not None
        sampler_outputs = []
@@ -2312,3 +2409,32 @@ def shutdown_inc(self):
    def __del__(self):
        self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # If there's no output to patch with, which is usually the case when
+        # we're starting a new request after all requests have completed,
+        # there is nothing to do.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack: assigning to output_token_ids would trigger
+            # a cache recomputation, and we only need to update the last token.
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
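The final patch writes only the trailing element of the sequence's token containers, which avoids the full-cache recomputation that assigning `output_token_ids` would cause. A toy illustration with stand-in containers (the surrounding `SequenceData` class is not reproduced here):

```python
from array import array

# Stand-ins for the two fields touched above; the last slot holds the dummy.
output_token_ids_array = array('l', [5, 9, -1])
_cached_all_token_ids = [1, 2, 5, 9, -1]

real_out = 42
# Patch just the trailing element, mirroring _patch_prev_output.
output_token_ids_array[-1] = real_out
_cached_all_token_ids[-1] = real_out
print(list(output_token_ids_array), _cached_all_token_ids)
```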