
Commit 48f7f49

xuechendi and madamczyk-intel authored and committed
[INTEL-HPU][v0] Port delayed sampling to upstream (vllm-project#16949)
Signed-off-by: Michal Adamczyk <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: Michal Adamczyk <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 9ee22de commit 48f7f49

File tree

2 files changed (+140, -7 lines)


vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -98,6 +98,7 @@
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
+    VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
     VLLM_DP_RANK: int = 0
     VLLM_DP_RANK_LOCAL: int = -1
     VLLM_DP_SIZE: int = 1
@@ -650,6 +651,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
     ("1", "true"),
 
+    # Use delayed sampling for HPU to reduce host cpu overhead
+    # between each step.
+    "VLLM_HPU_USE_DELAYED_SAMPLING":
+    lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
+    ("1", "true"),
+
     # Rank of the process in the data parallel setting
     "VLLM_DP_RANK":
     lambda: int(os.getenv("VLLM_DP_RANK", "0")),
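As an illustration (not part of the commit), here is a minimal sketch of how this flag is resolved at runtime. Note that the registered key is VLLM_HPU_USE_DELAYED_SAMPLING while the getter above reads the VLLM_DELAYED_SAMPLING environment variable; the snippet simply mirrors that lookup.

    import os

    # Mirror of the lookup registered above; the getter checks the
    # VLLM_DELAYED_SAMPLING environment variable.
    os.environ["VLLM_DELAYED_SAMPLING"] = "true"
    enabled = os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in ("1", "true")
    print(enabled)  # True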

vllm/worker/hpu_model_runner.py

Lines changed: 133 additions & 7 deletions
@@ -74,6 +74,8 @@
 
 LORA_WARMUP_RANK = 8
 
+DUMMY_TOKEN_ID = -1
+
 
 class Singleton(type):
     _instances: Dict[type, object] = {}
@@ -668,6 +670,9 @@ def __init__(
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -771,6 +776,12 @@ def load_model(self) -> None:
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)
 
+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
     def get_model(self) -> nn.Module:
         return self.model
 
@@ -2020,6 +2031,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
     @torch.inference_mode()
     def execute_model(
         self,
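As an illustration (not part of the commit), a standalone sketch of what _pad_to_max_num_seqs does: sampled-token tensors are padded with DUMMY_TOKEN_ID rows up to a fixed batch size so shapes stay static. Here max_num_seqs is a stand-in for self.max_num_seqs.

    import torch

    DUMMY_TOKEN_ID = -1
    max_num_seqs = 4  # stand-in for self.max_num_seqs

    sampled = torch.tensor([[101], [102]])  # 2 live sequences, 1 token each
    padding_needed = max_num_seqs - sampled.size(0)
    if padding_needed:
        padding = torch.full((padding_needed, *sampled.shape[1:]),
                             DUMMY_TOKEN_ID,
                             device=sampled.device,
                             dtype=sampled.dtype)
        sampled = torch.cat([sampled, padding])
    print(sampled.tolist())  # [[101], [102], [-1], [-1]]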
@@ -2030,6 +2056,37 @@ def execute_model(
         warmup_mode=False,
         seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+                htorch.core.mark_step()
+
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
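As an aside (not part of the commit), the remapping performed above can be shown with a pure-Python sketch using made-up sequence ids: each cached output row is routed to the position its sequence now occupies in the current decode batch, and sequences that are no longer present map to -1.

    prev_seq_ids = [7, 8, 9]   # order of sequences in the cached (previous) step
    cur_seq_ids = [9, 7]       # order of sequences in the current decode batch

    cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids) if sid >= 0}
    target_indices = [cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids]
    print(target_indices)  # [1, -1, 0]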
@@ -2045,7 +2102,21 @@
             assert model_input.lora_mapping is not None
             self.set_active_loras(model_input.lora_requests,
                                   model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
+        # Rank!=0 workers has is_prompt==None
+        if use_delayed_sampling and not model_input.is_prompt and \
+            model_input.input_tokens.size(1) == 1:
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+        else:
+            input_tokens = model_input.input_tokens
         input_positions = model_input.input_positions
         attn_metadata = model_input.attn_metadata
         sampling_metadata = model_input.sampling_metadata
@@ -2092,7 +2163,7 @@
                                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        if num_steps > 1:
+        if num_steps > 1 or use_delayed_sampling:
             # in case of multi-step scheduling
             # we only want to pythonize in the last step
             sampling_metadata.skip_sampler_cpu_output = True
@@ -2152,9 +2223,9 @@ def try_revert_dummy_output_tokens():
                 if not self.is_driver_worker:
                     continue
 
-                if model_input.async_callback is not None:
-                    model_input.async_callback()
-                # Sample the next token.
+                if use_delayed_sampling:
+                    fake_output = self._delayed_sampler_outputs(model_input)
+
                 with self.profiler.record_event(
                         'internal', ('sample_'
                                      f'{"prompt" if is_prompt else "decode"}_'
@@ -2166,9 +2237,16 @@ def try_revert_dummy_output_tokens():
                 )
                 if num_steps > 1:
                     output = output.sampled_token_ids
-                    self.cached_step_outputs.append(
-                        output.detach().clone())
+                    self.cached_step_outputs.append(output)
+                if use_delayed_sampling and self.is_driver_worker:
+                    self._patch_prev_output()
+                    output = self._pad_to_max_num_seqs(
+                        output.sampled_token_ids, DUMMY_TOKEN_ID)
+                    self.cached_step_outputs.append(output)
+                    self.cached_step_inputs.append(model_input)
                 htorch.core.mark_step()
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
                 if i < num_steps - 1:
                     if i == 0:
                         if model_input.async_callback is not None:
@@ -2241,11 +2319,30 @@ def try_revert_dummy_output_tokens():
                 is_prompt=is_prompt)
             self.profiler.record_counter(self.event_start, counters)
         if num_steps == 1:
+            if self.return_hidden_states:
+                # we only need to pass hidden states of most recent token
+                assert model_input.sampling_metadata is not None
+                if model_input.is_prompt:
+                    output.prefill_hidden_states = hidden_states
+                output.hidden_states = hidden_states
+            if use_delayed_sampling:
+                if self.is_driver_worker:
+                    return [fake_output]
+                else:
+                    return []
+
             return [output] if self.is_driver_worker else []
         else:
             return []
         return output if type(output) is list else [output]
 
+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
     def _decode_sampler_outputs(self, model_input):
         use_async_out_proc = model_input.async_callback is not None
         sampler_outputs = []
@@ -2312,3 +2409,32 @@ def shutdown_inc(self):
 
     def __del__(self):
         self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # If there's no output to patch with, which is usually the case when
+        # we're starting a new request after all requests are completed.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack. Assigning output_token_ids triggers
+            # a cache recomputation and we only need to update the last token
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
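To summarize the mechanism, here is a toy, self-contained sketch of the patching idea behind _patch_prev_output (the classes below are stand-ins, not vLLM types): the output queued for the previous step holds DUMMY_TOKEN_ID placeholders, and once the real token ids arrive from the device they are written over the placeholders in place.

    from dataclasses import dataclass, field
    from typing import List

    DUMMY_TOKEN_ID = -1

    @dataclass
    class FakeSample:          # stand-in for a queued sample
        output_token: int = DUMMY_TOKEN_ID

    @dataclass
    class FakeSeqData:         # stand-in for per-sequence token history
        output_token_ids: List[int] = field(
            default_factory=lambda: [DUMMY_TOKEN_ID])

    def patch_prev_output(samples, seq_data, real_tokens):
        # Overwrite the placeholder token from the previous step in place.
        for sample, real in zip(samples, real_tokens):
            sample.output_token = real
        for sd, real in zip(seq_data, real_tokens):
            sd.output_token_ids[-1] = real

    samples = [FakeSample(), FakeSample()]
    seq_data = [FakeSeqData(), FakeSeqData()]
    patch_prev_output(samples, seq_data, real_tokens=[42, 7])
    print([s.output_token for s in samples])         # [42, 7]
    print([sd.output_token_ids for sd in seq_data])  # [[42], [7]]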
