[Core][Bugfix] Fix Offline MM Beam Search #16390

Merged: 10 commits, Apr 15, 2025
Changes from 8 commits
32 changes: 20 additions & 12 deletions tests/conftest.py
@@ -29,12 +29,11 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         TokensPrompt, to_enc_dec_tuple_list,
-                         zip_enc_dec_prompts)
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import cuda_device_count_stateless, is_list_of
+from vllm.utils import cuda_device_count_stateless

 logger = init_logger(__name__)

@@ -469,12 +468,19 @@ def generate_beam_search(
         prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 num_beams=beam_width,
-                                num_return_sequences=beam_width)
+                                num_return_sequences=beam_width,
+                                images=images,
+                                videos=videos,
+                                audios=audios)

         for i in range(len(outputs)):
             output_ids, output_str = outputs[i]
             for j in range(len(output_ids)):
@@ -936,18 +942,20 @@ def generate_encoder_decoder_greedy_logprobs(

     def generate_beam_search(
         self,
-        prompts: Union[list[str], list[list[int]]],
+        prompts: list[str],
         beam_width: int,
         max_tokens: int,
+        images: Optional[PromptImageInput] = None,
+        videos: Optional[PromptVideoInput] = None,
+        audios: Optional[PromptAudioInput] = None,
     ) -> list[tuple[list[list[int]], list[str]]]:
-        if is_list_of(prompts, str, check="all"):
-            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
-        else:
-            prompts = [
-                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
-            ]
+        inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)

         outputs = self.model.beam_search(
-            prompts,
+            inputs,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:

Contributor Author (review comment on the removed TokensPrompt branch): Seems like we currently don't use this, so I removed it for now to keep things clean and consistent with the other helpers / get_inputs. Happy to put it back later on if it ends up being needed, though.
77 changes: 74 additions & 3 deletions tests/samplers/test_beam_search.py
@@ -5,6 +5,9 @@
 """

 import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.assets.audio import AudioAsset


 @pytest.fixture(autouse=True)
@@ -48,15 +51,83 @@ def test_beam_search_single_input(
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
         vllm_output_ids, vllm_output_texts = vllm_outputs[i]
-        for i, (hf_text,
+        for j, (hf_text,
                 vllm_text) in enumerate(zip(hf_output_texts,
                                             vllm_output_texts)):
-            print(f">>>{i}-th hf output:")
+            print(f">>>{j}-th hf output:")
             print(hf_text)
-            print(f">>>{i}-th vllm output:")
+            print(f">>>{j}-th vllm output:")
             print(vllm_text)
         assert len(hf_output_ids) == len(vllm_output_ids)
         for j in range(len(hf_output_ids)):
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")


+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
+@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
+def test_beam_search_passes_multimodal_data(
+    hf_runner,
+    vllm_runner,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Ensure that beam search passes multimodal data through correctly."""
+    # NOTE - this test is primarily to check that mm data is passed to beams
+    # correctly. As such, we just need to check one extra modality to make
+    # sure things pass through properly.
+    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
+    model = "Qwen/Qwen2-Audio-7B-Instruct"
+    audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+    prompts = [
+        f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"  #noqa: E501
+    ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        audio_token_id = hf_model.config.audio_token_index
+        hf_outputs = hf_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_beam_search(
+            prompts,
+            beam_width=beam_width,
+            max_tokens=max_tokens,
+            audios=audios,
+        )
+
+    seq_with_no_audio_toks = lambda seq: [
+        tok for tok in seq if tok != audio_token_id
+    ]
+
+    for i in range(len(prompts)):
+        hf_output_ids, hf_output_texts = hf_outputs[i]
+        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
+
+        for j, (hf_text,
+                vllm_text) in enumerate(zip(hf_output_texts,
+                                            vllm_output_texts)):
+            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
+            print(hf_text)
+            print(f">>>{j}-th vllm output:")
+            print(vllm_text)
+        assert len(hf_output_ids) == len(vllm_output_ids)
+
+        for j in range(len(hf_output_ids)):
+            # Compare everything except for the audio tokens; we do this
+            # since the IDs returned from the transformers helper expand the
+            # audio token to match the number of audio features, while the
+            # vLLM helper keeps the single audio token from the input text.
+            filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
+            filtered_vllm_output_ids = seq_with_no_audio_toks(
+                vllm_output_ids[j])
+
+            assert filtered_hf_output_ids == filtered_vllm_output_ids
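
To make the comparison logic above concrete, here is a tiny self-contained illustration of why the audio token is stripped before comparing IDs. The token IDs below are invented for the example and do not correspond to the real Qwen2-Audio vocabulary.

```python
# Hypothetical IDs, for illustration only: the HF processor expands the
# single <|AUDIO|> placeholder into one token per audio feature frame,
# while the vLLM helper keeps the single token from the input text.
AUDIO_TOKEN_ID = 151646  # made-up stand-in for audio_token_index

hf_ids = [101, 102, AUDIO_TOKEN_ID, AUDIO_TOKEN_ID, AUDIO_TOKEN_ID, 103]
vllm_ids = [101, 102, AUDIO_TOKEN_ID, 103]


def strip_audio(seq: list[int]) -> list[int]:
    """Drop the audio placeholder tokens before comparing sequences."""
    return [tok for tok in seq if tok != AUDIO_TOKEN_ID]


# Once the audio placeholders are removed, the two sides should agree.
assert strip_audio(hf_ids) == strip_audio(vllm_ids)
```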
13 changes: 11 additions & 2 deletions vllm/beam_search.py
@@ -38,9 +38,18 @@ class BeamSearchOutput:

 class BeamSearchInstance:

-    def __init__(self, prompt_tokens: list[int]):
+    def __init__(
+        self,
+        prompt_tokens: list[int],
+        logprobs: Optional[list[dict[int, Logprob]]] = None,
+        **kwargs,
+    ):
         self.beams: list[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
+            BeamSearchSequence(
+                tokens=prompt_tokens,
+                logprobs=[] if logprobs is None else list(logprobs),
+                **kwargs,
+            )
         ]
         self.completed: list[BeamSearchSequence] = []
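
For readers skimming the diff, a minimal sketch of how the widened constructor is exercised by LLM.beam_search() later in this PR. The field names (multi_modal_data, mm_processor_kwargs) come from the diff itself; the payload below is a placeholder rather than real audio data.

```python
from vllm.beam_search import BeamSearchInstance

# Placeholder multimodal kwargs; in the PR these are copied verbatim from
# the incoming prompt dict.
mm_kwargs = {
    "multi_modal_data": {"audio": []},
    "mm_processor_kwargs": {},
}

instance = BeamSearchInstance(prompt_tokens=[1, 2, 3], **mm_kwargs)

# The extra kwargs are forwarded to the root BeamSearchSequence, so the
# initial beam already carries the prompt's multimodal inputs.
root_beam = instance.beams[0]
assert root_beam.tokens == [1, 2, 3]
assert root_beam.multi_modal_data == {"audio": []}
```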
32 changes: 18 additions & 14 deletions vllm/entrypoints/llm.py
@@ -536,16 +536,6 @@
                                          tokenizer.eos_token_id,
                                          length_penalty)

-        # TODO - fix handling of multimodal data for beam search; we pass it
-        # through in the async version on the abstract EngineClient, but not
-        # here.
-        if any("multi_modal_data" in prompt
-               and prompt["multi_modal_data"] is not None
-               for prompt in prompts):
-            logger.warning(
-                "Multimodal data appears to have been provided, but is not"
-                " currently being passed through in LLM.beam_search()!")
-
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
@@ -556,11 +546,19 @@
         instances: list[BeamSearchInstance] = []

         for prompt in prompts:
+            # Add multimodal processor kwargs & data
+            mm_kwargs = {}
+            if "multi_modal_data" in prompt:
+                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
+            if "mm_processor_kwargs" in prompt:
+                mm_kwargs["mm_processor_kwargs"] = prompt[
+                    "mm_processor_kwargs"]
+
             if is_token_prompt(prompt):
                 prompt_tokens = prompt["prompt_token_ids"]
             else:
                 prompt_tokens = tokenizer.encode(prompt["prompt"])
-            instances.append(BeamSearchInstance(prompt_tokens))
+            instances.append(BeamSearchInstance(prompt_tokens, **mm_kwargs))

         for _ in range(max_tokens):
             all_beams: list[BeamSearchSequence] = list(

GitHub Actions / pre-commit, check failure on line 561 in vllm/entrypoints/llm.py (mypy, reported four times): Argument 2 to "BeamSearchInstance" has incompatible type "**dict[str, Mapping[str, Union[Any, list[Any]]]]"; expected "Optional[list[dict[int, Logprob]]]" [arg-type]
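
One plausible way to quiet the [arg-type] failure above, sketched purely for illustration (it is not necessarily the fix adopted in the later commits of this PR): give the collected kwargs an explicit dict[str, Any] annotation so mypy does not try to match the splatted values against BeamSearchInstance's typed logprobs parameter. The helper name is hypothetical.

```python
from typing import Any


def collect_mm_kwargs(prompt: dict[str, Any]) -> dict[str, Any]:
    """Pull the optional multimodal fields out of a prompt dict."""
    mm_kwargs: dict[str, Any] = {}
    if "multi_modal_data" in prompt:
        mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
    if "mm_processor_kwargs" in prompt:
        mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
    return mm_kwargs


# e.g. instances.append(
#     BeamSearchInstance(prompt_tokens, **collect_mm_kwargs(prompt)))
```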
@@ -575,8 +573,11 @@
                 break

             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
+                TokensPrompt(
+                    prompt_token_ids=beam.tokens,
+                    multi_modal_data=beam.multi_modal_data,
+                    mm_processor_kwargs=beam.mm_processor_kwargs,
+                ) for beam in all_beams
             ]

             # only runs for one step

GitHub Actions / pre-commit, check failure on line 578 in vllm/entrypoints/llm.py (mypy, reported three times): Incompatible types (expression has type "Optional[Mapping[str, Union[Any, list[Any]]]]", TypedDict item "multi_modal_data" has type "Mapping[str, Union[Any, list[Any]]]") [typeddict-item]

GitHub Actions / pre-commit, check failure on line 579 in vllm/entrypoints/llm.py (mypy, reported three times): Incompatible types (expression has type "Optional[dict[str, Any]]", TypedDict item "mm_processor_kwargs" has type "dict[str, Any]") [typeddict-item]
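
Similarly, a sketch of one way the [typeddict-item] failures above could be avoided, again purely illustrative rather than the fix that landed: only set the optional TypedDict keys when the beam actually carries a value, instead of assigning None into fields whose declared types are not Optional. The helper name is hypothetical.

```python
from typing import Any


def beam_to_prompt_kwargs(tokens: list[int],
                          multi_modal_data: Any = None,
                          mm_processor_kwargs: Any = None) -> dict[str, Any]:
    """Build kwargs for TokensPrompt, skipping keys whose value is None."""
    kwargs: dict[str, Any] = {"prompt_token_ids": tokens}
    if multi_modal_data is not None:
        kwargs["multi_modal_data"] = multi_modal_data
    if mm_processor_kwargs is not None:
        kwargs["mm_processor_kwargs"] = mm_processor_kwargs
    return kwargs


# e.g. prompts_batch = [
#     TokensPrompt(**beam_to_prompt_kwargs(
#         beam.tokens, beam.multi_modal_data, beam.mm_processor_kwargs))
#     for beam in all_beams]
```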
@@ -602,7 +603,10 @@
                         tokens=current_beam.tokens + [token_id],
                         logprobs=current_beam.logprobs + [logprobs],
                         cum_logprob=current_beam.cum_logprob +
-                        logprob_obj.logprob)
+                        logprob_obj.logprob,
+                        multi_modal_data=current_beam.multi_modal_data,
+                        mm_processor_kwargs=current_beam.
+                        mm_processor_kwargs)

                 if token_id == tokenizer.eos_token_id and \
                     not ignore_eos:
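
Finally, an end-to-end sketch of what this PR enables: offline beam search where multimodal data is actually passed through LLM.beam_search(). The model name and prompt template mirror the new test; the exact multi_modal_data payload shape is an assumption based on vLLM's offline audio examples rather than something pinned down by this diff.

```python
from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct", dtype="half")
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

prompt = {
    "prompt": ("<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
               "Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n"),
    # Payload shape assumed from vLLM's offline audio examples.
    "multi_modal_data": {"audio": [audio]},
}

outputs = llm.beam_search(
    [prompt], BeamSearchParams(beam_width=2, max_tokens=64))

# Each BeamSearchOutput holds the completed beams for one prompt.
for sequence in outputs[0].sequences:
    print(sequence.text)
```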