Commit 9dc0313

xuechendi authored and madamczyk-intel committed
[INTEL-HPU][v0] Port delayed sampling to upstream (vllm-project#16949)
Signed-off-by: Michal Adamczyk <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
Co-authored-by: Michal Adamczyk <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 3fcb4e3 commit 9dc0313

2 files changed: +140, -7 lines changed

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -98,6 +98,7 @@
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
     VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
+    VLLM_HPU_USE_DELAYED_SAMPLING: bool = False
     VLLM_DP_RANK: int = 0
     VLLM_DP_RANK_LOCAL: int = -1
     VLLM_DP_SIZE: int = 1
@@ -650,6 +651,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
     ("1", "true"),
 
+    # Use delayed sampling for HPU to reduce host cpu overhead
+    # between each step.
+    "VLLM_HPU_USE_DELAYED_SAMPLING":
+    lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in
+    ("1", "true"),
+
     # Rank of the process in the data parallel setting
     "VLLM_DP_RANK":
     lambda: int(os.getenv("VLLM_DP_RANK", "0")),
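
The new flag follows the same lazy-evaluation pattern as the other entries in environment_variables: the module attribute is resolved through its getter lambda at access time. A minimal sketch of toggling the flag from user code is below; note that in the hunk above the dictionary key is VLLM_HPU_USE_DELAYED_SAMPLING while its getter reads the VLLM_DELAYED_SAMPLING environment variable, so the sketch sets both names. The sketch is illustrative and not part of this commit.

    import os

    # Illustrative only, not part of this commit. The diff registers the flag
    # under the key "VLLM_HPU_USE_DELAYED_SAMPLING", but its getter lambda
    # reads os.environ["VLLM_DELAYED_SAMPLING"]; setting both covers either
    # spelling.
    os.environ["VLLM_DELAYED_SAMPLING"] = "true"
    os.environ["VLLM_HPU_USE_DELAYED_SAMPLING"] = "true"

    import vllm.envs as envs

    # vllm.envs resolves attributes lazily through the lambdas registered in
    # environment_variables, so this reflects the values set above.
    print(envs.VLLM_HPU_USE_DELAYED_SAMPLING)  # expected: True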

vllm/worker/hpu_model_runner.py

Lines changed: 133 additions & 7 deletions
@@ -75,6 +75,8 @@
 
 LORA_WARMUP_RANK = 8
 
+DUMMY_TOKEN_ID = -1
+
 
 class PhaseType(Enum):
     PREFILL = 'prefill'
@@ -701,6 +703,9 @@ def __init__(
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
+        # For delayed sampling
+        self.cached_step_inputs: List[
+            ModelInputForHPUWithSamplingMetadata] = []
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -804,6 +809,12 @@ def load_model(self) -> None:
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)
 
+    def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
+        return htorch.hpu.wrap_in_hpu_graph(
+            HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
+        ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(
+            *args, **kwargs)
+
     def get_model(self) -> nn.Module:
         return self.model
 
@@ -2115,6 +2126,21 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
+    def _get_seq_ids(self, model_input):
+        return ([
+            sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups
+        ])
+
+    def _pad_to_max_num_seqs(self, tensor, value):
+        padding_needed = self.max_num_seqs - tensor.size(0)
+        if padding_needed:
+            padding = torch.full((padding_needed, *tensor.shape[1:]),
+                                 value,
+                                 device=tensor.device,
+                                 dtype=tensor.dtype)
+            tensor = torch.cat([tensor, padding])
+        return tensor
+
     @torch.inference_mode()
     def execute_model(
         self,
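
For intuition, the padding helper added above appends rows filled with a sentinel value until the tensor covers max_num_seqs sequences. A standalone sketch of the same pattern (plain torch, outside the runner class, with an illustrative max_num_seqs) might look like this:

    import torch

    DUMMY_TOKEN_ID = -1  # sentinel used by the diff for padded rows
    max_num_seqs = 4     # illustrative; in the runner this is self.max_num_seqs

    def pad_to_max_num_seqs(tensor: torch.Tensor, value: int) -> torch.Tensor:
        # Append value-filled rows so the first dimension reaches max_num_seqs.
        padding_needed = max_num_seqs - tensor.size(0)
        if padding_needed:
            padding = torch.full((padding_needed, *tensor.shape[1:]),
                                 value,
                                 device=tensor.device,
                                 dtype=tensor.dtype)
            tensor = torch.cat([tensor, padding])
        return tensor

    sampled = torch.tensor([[11], [22]])  # two real sequences, one token each
    print(pad_to_max_num_seqs(sampled, DUMMY_TOKEN_ID))
    # tensor([[11], [22], [-1], [-1]])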
@@ -2125,6 +2151,37 @@ def execute_model(
         warmup_mode=False,
         seqs=None,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
+        use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
+        assert not (use_delayed_sampling and num_steps != 1), \
+            'Delayed sampling is not compatible with MSS!'
+        assert model_input.input_tokens is not None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                self.is_driver_worker:
+            num_cached = len(self.cached_step_outputs)
+            assert num_cached > 0
+            cur_seq_ids = self._get_seq_ids(model_input)
+            cur_seq_id_pos = {
+                sid: idx
+                for idx, sid in enumerate(cur_seq_ids) if sid >= 0
+            }
+            htorch.core.mark_step()
+            for i in range(num_cached):
+                prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i])
+                target_indices = [
+                    cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids
+                ]
+                padding = self.cached_step_outputs[i].size(0) - len(
+                    target_indices)
+                target_indices.extend([-1] * padding)
+                target_indices = torch.tensor(
+                    target_indices,
+                    device=model_input.input_tokens.device,
+                    dtype=model_input.input_tokens.dtype)
+                model_input.input_tokens.index_copy_(
+                    0, target_indices, self.cached_step_outputs[i])
+                htorch.core.mark_step()
+
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
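
Roughly, the block added above takes the token ids sampled in the previous step (still cached on device) and scatters them into the current step's input_tokens, mapping each previously scheduled sequence id to its row in the current batch. A simplified, self-contained illustration of that remapping follows; it uses plain torch with made-up sequence ids, assumes every previous sequence is still scheduled, and omits the -1 padding the real code uses for sequences that have finished:

    import torch

    # Hypothetical batches: sequence ids scheduled in the previous/current step.
    prev_seq_ids = [7, 3]                        # order of cached outputs
    cur_seq_ids = [3, 9, 7]                      # row order of this step's batch
    cached_outputs = torch.tensor([[70], [30]])  # token sampled last step per seq

    # Placeholder input tokens for the current decode step (one token per row).
    input_tokens = torch.zeros((3, 1), dtype=torch.long)

    # Map each previous sequence id to its row in the current batch.
    cur_seq_id_pos = {sid: idx for idx, sid in enumerate(cur_seq_ids)}
    target_indices = torch.tensor([cur_seq_id_pos[sid] for sid in prev_seq_ids])

    # Scatter the cached tokens into the rows of the sequences they belong to.
    input_tokens.index_copy_(0, target_indices, cached_outputs)
    print(input_tokens)  # rows for seq ids 3 and 7 now hold 30 and 70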
@@ -2140,7 +2197,21 @@ def execute_model(
             assert model_input.lora_mapping is not None
             self.set_active_loras(model_input.lora_requests,
                                   model_input.lora_mapping)
-        input_tokens = model_input.input_tokens
+        # Rank!=0 workers has is_prompt==None
+        if use_delayed_sampling and not model_input.is_prompt and \
+                model_input.input_tokens.size(1) == 1:
+            if self.is_driver_worker:
+                model_kwargs_broadcast_data = {
+                    "input_tokens": model_input.input_tokens
+                }
+                broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
+                input_tokens = model_input.input_tokens
+
+            else:
+                model_kwargs_broadcast_data = broadcast_tensor_dict(src=0)
+                input_tokens = model_kwargs_broadcast_data["input_tokens"]
+        else:
+            input_tokens = model_input.input_tokens
         input_positions = model_input.input_positions
         attn_metadata = model_input.attn_metadata
         sampling_metadata = model_input.sampling_metadata
@@ -2187,7 +2258,7 @@ def execute_model(
                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        if num_steps > 1:
+        if num_steps > 1 or use_delayed_sampling:
             # in case of multi-step scheduling
             # we only want to pythonize in the last step
             sampling_metadata.skip_sampler_cpu_output = True
@@ -2247,9 +2318,9 @@ def try_revert_dummy_output_tokens():
                 if not self.is_driver_worker:
                     continue
 
-                if model_input.async_callback is not None:
-                    model_input.async_callback()
-                # Sample the next token.
+                if use_delayed_sampling:
+                    fake_output = self._delayed_sampler_outputs(model_input)
+
                 with self.profiler.record_event(
                         'internal', ('sample_'
                                      f'{"prompt" if is_prompt else "decode"}_'
@@ -2261,9 +2332,16 @@ def try_revert_dummy_output_tokens():
                     )
                 if num_steps > 1:
                     output = output.sampled_token_ids
-                    self.cached_step_outputs.append(
-                        output.detach().clone())
+                    self.cached_step_outputs.append(output)
+                if use_delayed_sampling and self.is_driver_worker:
+                    self._patch_prev_output()
+                    output = self._pad_to_max_num_seqs(
+                        output.sampled_token_ids, DUMMY_TOKEN_ID)
+                    self.cached_step_outputs.append(output)
+                    self.cached_step_inputs.append(model_input)
                 htorch.core.mark_step()
+                if model_input.async_callback is not None:
+                    model_input.async_callback()
                 if i < num_steps - 1:
                     if i == 0:
                         if model_input.async_callback is not None:
@@ -2336,11 +2414,30 @@ def try_revert_dummy_output_tokens():
                 is_prompt=is_prompt)
             self.profiler.record_counter(self.event_start, counters)
         if num_steps == 1:
+            if self.return_hidden_states:
+                # we only need to pass hidden states of most recent token
+                assert model_input.sampling_metadata is not None
+                if model_input.is_prompt:
+                    output.prefill_hidden_states = hidden_states
+                output.hidden_states = hidden_states
+            if use_delayed_sampling:
+                if self.is_driver_worker:
+                    return [fake_output]
+                else:
+                    return []
+
             return [output] if self.is_driver_worker else []
         else:
             return []
         return output if type(output) is list else [output]
 
+    def _delayed_sampler_outputs(self, model_input):
+        next_token_ids = [[DUMMY_TOKEN_ID]] * len(
+            model_input.sampling_metadata.seq_groups)
+        sampler_output = self._make_decode_output(
+            next_token_ids, model_input.sampling_metadata.seq_groups)
+        return sampler_output
+
     def _decode_sampler_outputs(self, model_input):
         use_async_out_proc = model_input.async_callback is not None
         sampler_outputs = []
@@ -2407,3 +2504,32 @@ def shutdown_inc(self):
 
     def __del__(self):
         self.shutdown_inc()
+
+    def _patch_prev_output(self):
+        assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \
+            f'''Inputs and outputs are out of sync!
+            {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}'''
+        if len(self.cached_step_inputs) == 0:
+            return
+        model_input = self.cached_step_inputs.pop(0)
+        delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze(
+            -1).tolist()
+        ctx = model_input.async_callback.keywords["ctx"]  # type: ignore
+        # If there's no output to patch with, which is usually the case when
+        # we're starting a new request after all requests are completed.
+        if len(ctx.output_queue) == 0:
+            return
+        assert len(
+            ctx.output_queue) == 1, 'There should be exactly 1 output waiting!'
+        output_data = ctx.output_queue[0]
+        assert len(output_data.outputs) == 1
+        for fake_out, real_out in zip(output_data.outputs[0], delayed_output):
+            fake_out.samples[0].output_token = real_out
+        for sg, real_out in zip(output_data.seq_group_metadata_list,
+                                delayed_output):
+            assert len(sg.seq_data) == 1
+            seq_data = list(sg.seq_data.values())[0]
+            # This is a hack. Assigning output_token_ids triggers
+            # a cache recomputation and we only need to update the last token
+            seq_data.output_token_ids_array[-1] = real_out
+            seq_data._cached_all_token_ids[-1] = real_out
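
Taken together, the delayed-sampling path works in three beats: execute_model returns a placeholder output filled with DUMMY_TOKEN_ID, caches the real (still on-device) sampled tokens together with the model input that produced them, and on the next step _patch_prev_output rewrites the queued output and the sequence data with the now-available real token ids. The sketch below mimics only that last patching step with a deliberately simplified, hypothetical output-queue structure (plain Python, not vLLM's actual SequenceGroupOutput/SequenceData classes), just to show the substitution of dummy tokens by real ones:

    from dataclasses import dataclass

    DUMMY_TOKEN_ID = -1

    # Hypothetical stand-ins for the queued output structures; vLLM's real
    # types are richer than this.
    @dataclass
    class FakeSample:
        output_token: int

    @dataclass
    class FakeSeqOutput:
        samples: list

    @dataclass
    class FakeSeqData:
        output_token_ids: list

    def patch_prev_output(queued_outputs, queued_seq_data, delayed_token_ids):
        # Replace the dummy token recorded last step with the real sampled one.
        for fake_out, real_tok in zip(queued_outputs, delayed_token_ids):
            fake_out.samples[0].output_token = real_tok
        for seq_data, real_tok in zip(queued_seq_data, delayed_token_ids):
            seq_data.output_token_ids[-1] = real_tok

    outs = [FakeSeqOutput([FakeSample(DUMMY_TOKEN_ID)]) for _ in range(2)]
    seqs = [FakeSeqData([DUMMY_TOKEN_ID]) for _ in range(2)]
    patch_prev_output(outs, seqs, delayed_token_ids=[42, 99])
    print([o.samples[0].output_token for o in outs])  # [42, 99]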
