
Commit da5304f

alex-jw-brooks authored and DarkLight1337 committed
[Core][Bugfix] Fix Offline MM Beam Search (vllm-project#16390)
Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 8aa15aa commit da5304f

File tree

4 files changed: +140 −30 lines changed
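
For reference, a rough offline usage sketch of the flow this commit fixes (not part of the diff): the model, audio asset, and chat-template prompt are taken from the new test below, while constructing the LLM directly and printing the beams is an assumption about typical use.

from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.sampling_params import BeamSearchParams

# Audio prompt copied from tests/samplers/test_beam_search.py in this commit.
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
prompt = ("<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
          "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n")

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct", dtype="half")

# Before this fix, LLM.beam_search() only warned that multi_modal_data was not
# being passed through; it is now attached to every beam at each search step.
outputs = llm.beam_search(
    [{"prompt": prompt, "multi_modal_data": {"audio": audio}}],
    BeamSearchParams(beam_width=2, max_tokens=64),
)
for seq in outputs[0].sequences:  # beam_width candidate sequences per prompt
    print(seq.text)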

tests/conftest.py

Lines changed: 20 additions & 12 deletions
@@ -29,12 +29,11 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, to_enc_dec_tuple_list,
-                         zip_enc_dec_prompts)
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import cuda_device_count_stateless, is_list_of
+from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
@@ -469,12 +468,19 @@ def generate_beam_search(
         prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 num_beams=beam_width,
-                                num_return_sequences=beam_width)
+                                num_return_sequences=beam_width,
+                                images=images,
+                                videos=videos,
+                                audios=audios)
+
         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
@@ -936,18 +942,20 @@ def generate_encoder_decoder_greedy_logprobs(
 
     def generate_beam_search(
         self,
-        prompts: Union[list[str], list[list[int]]],
+        prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
         outputs = self.model.beam_search(
-            prompts,
+            inputs,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:

tests/samplers/test_beam_search.py

Lines changed: 82 additions & 3 deletions
@@ -5,6 +5,9 @@
 """
 
 import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.assets.audio import AudioAsset
 
 
 @pytest.fixture(autouse=True)
@@ -19,6 +22,7 @@ def v1(run_with_both_engines):
 # 3. Use the model "huggyllama/llama-7b".
 MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
+MM_BEAM_WIDTHS = [2]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
 
 
@@ -48,15 +52,90 @@ def test_beam_search_single_input(
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
         vllm_output_ids, vllm_output_texts = vllm_outputs[i]
-        for i, (hf_text,
+        for j, (hf_text,
                 vllm_text) in enumerate(zip(hf_output_texts,
                                             vllm_output_texts)):
-            print(f">>>{i}-th hf output:")
+            print(f">>>{j}-th hf output:")
             print(hf_text)
-            print(f">>>{i}-th vllm output:")
+            print(f">>>{j}-th vllm output:")
             print(vllm_text)
         assert len(hf_output_ids) == len(vllm_output_ids)
         for j in range(len(hf_output_ids)):
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
+def test_beam_search_passes_multimodal_data(
+    hf_runner,
+    vllm_runner,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Ensure that beam search passes multimodal data through correctly."""
+    # NOTE - this test is primarily to check that mm data is passed to beams
+    # correctly. As such, we just need to check one extra modality to make
+    # sure things pass through properly.
+    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
+    model = "Qwen/Qwen2-Audio-7B-Instruct"
+    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+    prompts = [
+        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501
+    ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        audio_token_id = hf_model.config.audio_token_index
+        eos_token_id = hf_model.tokenizer.eos_token_id  # <|im_end|>
+        hf_outputs = hf_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    seq_with_no_audio_toks = lambda seq: [
+        tok for tok in seq if tok != audio_token_id
+    ]
+
+    for i in range(len(prompts)):
+        hf_output_ids, hf_output_texts = hf_outputs[i]
+        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
+
+        for j, (hf_text,
+                vllm_text) in enumerate(zip(hf_output_texts,
+                                            vllm_output_texts)):
+            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
+            print(hf_text)
+            print(f">>>{j}-th vllm output:")
+            print(vllm_text)
+        assert len(hf_output_ids) == len(vllm_output_ids)
+
+        for j in range(len(hf_output_ids)):
+            # Compare everything except for the audio tokens; we do this since
+            # the IDs returned from the transformers helper expands the audio
+            # token to match features, while the vLLM helper maintains the
+            # single audio token in the input text
+            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
+            filtered_vllm_output_ids = seq_with_no_audio_toks(
+                vllm_output_ids[j])
+
+            # HF output IDs may contain the end of sequence
+            if len(filtered_hf_output_ids
+                   ) == len(filtered_vllm_output_ids) + 1:
+                assert filtered_hf_output_ids[-1] == eos_token_id
+                filtered_hf_output_ids = filtered_hf_output_ids[:-1]
+
+            assert filtered_hf_output_ids == filtered_vllm_output_ids

vllm/beam_search.py

Lines changed: 11 additions & 2 deletions
@@ -38,9 +38,18 @@ class BeamSearchOutput:
 
 class BeamSearchInstance:
 
-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
 
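
A minimal standalone sketch (simplified; the token ids and payload below are placeholders) of what the widened BeamSearchInstance constructor enables: extra keyword arguments such as multi_modal_data are forwarded into the seed BeamSearchSequence, which is how LLM.beam_search() below initializes each instance.

from vllm.beam_search import BeamSearchInstance

prompt_tokens = [151644, 872, 198]   # hypothetical prompt token ids
mm_data = {"audio": (None, 16000)}   # placeholder (waveform, sample_rate) pair

instance = BeamSearchInstance(prompt_tokens,
                              logprobs=None,
                              multi_modal_data=mm_data)
# The initial beam now carries the multimodal payload alongside the tokens.
assert instance.beams[0].multi_modal_data is mm_data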

vllm/entrypoints/llm.py

Lines changed: 27 additions & 13 deletions
@@ -536,15 +536,18 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                                                tokenizer.eos_token_id,
                                                length_penalty)
 
-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
+        def create_tokens_prompt_from_beam(
+                beam: BeamSearchSequence) -> TokensPrompt:
+            token_prompt_kwargs: TokensPrompt = {
+                "prompt_token_ids": beam.tokens
+            }
+            if beam.multi_modal_data is not None:
+                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
+
+            if beam.mm_processor_kwargs is not None:
+                token_prompt_kwargs[
+                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
+            return TokensPrompt(**token_prompt_kwargs)
 
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
@@ -556,11 +559,20 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
         instances: list[BeamSearchInstance] = []
 
         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(
+                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
 
         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(
@@ -575,8 +587,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                 break
 
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                create_tokens_prompt_from_beam(beam) for beam in all_beams
             ]
 
             # only runs for one step
@@ -602,7 +613,10 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                         tokens=current_beam.tokens + [token_id],
                         logprobs=current_beam.logprobs + [logprobs],
                         cum_logprob=current_beam.cum_logprob +
-                        logprob_obj.logprob)
+                        logprob_obj.logprob,
+                        multi_modal_data=current_beam.multi_modal_data,
+                        mm_processor_kwargs=current_beam.
+                        mm_processor_kwargs)
 
                     if token_id == tokenizer.eos_token_id and \
                         not ignore_eos:
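
As a small illustration (simplified, not taken from the diff) of the last hunk above: each candidate created during beam expansion copies its parent's multimodal payload, so the data follows the beam through every decoding step instead of being dropped after the first one.

from vllm.beam_search import BeamSearchSequence

parent = BeamSearchSequence(tokens=[1, 2, 3],
                            logprobs=[],
                            cum_logprob=0.0,
                            multi_modal_data={"audio": (None, 16000)})

# Mirrors the new BeamSearchSequence(...) call above: the child appends one
# token and inherits multi_modal_data / mm_processor_kwargs from its parent.
child = BeamSearchSequence(
    tokens=parent.tokens + [42],
    logprobs=parent.logprobs + [{}],
    cum_logprob=parent.cum_logprob + (-0.7),
    multi_modal_data=parent.multi_modal_data,
    mm_processor_kwargs=parent.mm_processor_kwargs,
)
assert child.multi_modal_data == parent.multi_modal_data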
