From 2c61c7f8132eb7f2aebc7de99fb430a55e9cf88d Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:11:19 +0200 Subject: [PATCH 1/5] Improve `DecodingConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 32 ++++++++++++++++++++------------ vllm/engine/arg_utils.py | 38 ++++++++++++++++---------------------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7e2869e4eab..f791229fdd9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,7 @@ from importlib.util import find_spec from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, - Optional, Protocol, TypeVar, Union) + Optional, Protocol, TypeVar, Union, get_args) import torch from pydantic import BaseModel, Field, PrivateAttr @@ -3095,15 +3095,28 @@ def get_served_model_name(model: str, return served_model_name +GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", + "xgrammar"] +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] + + +@config @dataclass class DecodingConfig: - """Dataclass which contains the decoding strategy of the engine""" + """Dataclass which contains the decoding strategy of the engine.""" - # Which guided decoding algo to use. - # 'outlines' / 'lm-format-enforcer' / 'xgrammar' - guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar" + guided_decoding_backend: Union[ + GuidedDecodingBackendV0, + GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar" + """Which engine will be used for guided decoding (JSON schema / regex etc) + by default. With "auto", we will make opinionated choices based on request + contents and what the backend libraries currently support, so the behavior + is subject to change in each release.""" reasoning_backend: Optional[str] = None + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format. 
+ Required for `--enable-reasoning`.""" def compute_hash(self) -> str: """ @@ -3125,17 +3138,12 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - v0_valid_guided_backends = [ - 'outlines', 'lm-format-enforcer', 'xgrammar', 'auto' - ] - v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto'] - backend = GuidedDecodingParams( backend=self.guided_decoding_backend).backend_name if envs.VLLM_USE_V1: - valid_guided_backends = v1_valid_guided_backends + valid_guided_backends = get_args(GuidedDecodingBackendV1) else: - valid_guided_backends = v0_valid_guided_backends + valid_guided_backends = get_args(GuidedDecodingBackendV0) if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}'," f" must be one of {valid_guided_backends}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 85b3ddfce48..c11a4af1951 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -252,7 +252,7 @@ class EngineArgs: additional_config: Optional[Dict[str, Any]] = None enable_reasoning: Optional[bool] = None - reasoning_parser: Optional[str] = None + reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load def __post_init__(self): @@ -478,18 +478,22 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]: 'Examples:\n' '- 1k → 1000\n' '- 1K → 1024\n') - parser.add_argument( + + # Guided decoding arguments + guided_decoding_kwargs = get_kwargs(DecodingConfig) + guided_decoding_group = parser.add_argument_group( + title="DecodingConfig", + description=DecodingConfig.__doc__, + ) + guided_decoding_group.add_argument( '--guided-decoding-backend', - type=str, - default=DecodingConfig.guided_decoding_backend, - help='Which engine will be used for guided decoding' - ' (JSON schema / regex etc) by default. Currently support ' - 'https://github.com/mlc-ai/xgrammar and ' - 'https://github.com/guidance-ai/llguidance.' - 'Valid backend values are "xgrammar", "guidance", and "auto". ' - 'With "auto", we will make opinionated choices based on request ' - 'contents and what the backend libraries currently support, so ' - 'the behavior is subject to change in each release.') + **guided_decoding_kwargs["guided_decoding_backend"]) + guided_decoding_group.add_argument( + "--reasoning-parser", + # This choices is a special case because it's not static + choices=list(ReasoningParserManager.reasoning_parsers), + **guided_decoding_kwargs["reasoning_backend"]) + parser.add_argument( '--logits-processor-pattern', type=optional_str, @@ -1017,16 +1021,6 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]: "If enabled, the model will be able to generate reasoning content." ) - parser.add_argument( - "--reasoning-parser", - type=str, - choices=list(ReasoningParserManager.reasoning_parsers), - default=None, - help= - "Select the reasoning parser depending on the model that you're " - "using. This is used to parse the reasoning content into OpenAI " - "API format. 
Required for ``--enable-reasoning``.") - parser.add_argument( "--disable-cascade-attn", action="store_true", From b07005bf05cb841fc89606e7b44d94dab1dbc9c5 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:15:46 +0200 Subject: [PATCH 2/5] Improve `PoolerConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/config.py b/vllm/config.py index f791229fdd9..a8bf8a1a11a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2771,6 +2771,7 @@ def get_limit_per_prompt(self, modality: str) -> int: # TODO: Add configs to init vision tower or not. +@config @dataclass class PoolerConfig: """Controls the behavior of output pooling in pooling models.""" From 256eb4862f75cc934984a745200dbd3d56f3a27c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:28:17 +0200 Subject: [PATCH 3/5] Improve `MultiModalConfig` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/source/models/supported_models.md | 2 +- docs/source/serving/multimodal_inputs.md | 2 +- examples/offline_inference/mistral-small.py | 4 +-- ...i_chat_completion_client_for_multimodal.py | 2 +- tests/engine/test_arg_utils.py | 4 +++ tests/entrypoints/openai/test_audio.py | 2 +- tests/entrypoints/openai/test_video.py | 2 +- tests/entrypoints/openai/test_vision.py | 2 +- .../openai/test_vision_embedding.py | 2 +- .../audio_language/test_ultravox.py | 2 +- vllm/config.py | 16 ++++----- vllm/engine/arg_utils.py | 34 +++++++++---------- vllm/multimodal/processing.py | 5 +-- vllm/multimodal/registry.py | 6 ++-- 14 files changed, 43 insertions(+), 42 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 34917b5bfef..4df9c511ca3 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -788,7 +788,7 @@ llm = LLM( Online serving: ```bash -vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4 +vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' ``` **This is no longer required if you are using vLLM V1.** diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index f45d36c3cca..d9a093e8d14 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server: ```bash vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ - --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' ``` Then, you can use the OpenAI client as follows: diff --git a/examples/offline_inference/mistral-small.py b/examples/offline_inference/mistral-small.py index 9bb66fdbc45..af1831bd36d 100644 --- a/examples/offline_inference/mistral-small.py +++ b/examples/offline_inference/mistral-small.py @@ -16,11 +16,11 @@ # # Mistral format # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ # --tokenizer-mode mistral --config-format mistral --load-format mistral \ -# --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384 # # # HF format # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \ -# --limit-mm-per-prompt 'image=4' --max-model-len 16384 +# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384 # 
``` # # - Client: diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index 18006e2c423..70db4d95e64 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -9,7 +9,7 @@ (multi-image inference with Phi-3.5-vision-instruct) vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ - --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' (audio inference with Ultravox) vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 92387b46425..0203db877b2 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -24,6 +24,10 @@ }), ]) def test_limit_mm_per_prompt_parser(arg, expected): + """This functionality is deprecated and will be removed in the future. + This argument should be passed as JSON string instead. + + TODO: Remove with nullable_kvs.""" parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index b13002a5b68..e28a3622084 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -27,7 +27,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"audio={MAXIMUM_AUDIOS}", + f'{{"audio":{MAXIMUM_AUDIOS}}}', ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index f9ccce9c1c3..0a802651f5a 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -31,7 +31,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"video={MAXIMUM_VIDEOS}", + f'{{"video":{MAXIMUM_VIDEOS}}}', ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 4b9029ded41..c1894f3af01 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -35,7 +35,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + f'{{"image":{MAXIMUM_IMAGES}}}', ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 3e6f13e10ac..a0b19b2a094 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -37,7 +37,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + f'{{"image":{MAXIMUM_IMAGES}}}', "--chat-template", str(vlm2vec_jinja_path), ] diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index a843e41aa26..9f54ed9770d 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -49,7 +49,7 @@ def audio(request): def server(request, audio_assets): args = [ "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", - 
f"--limit-mm-per-prompt=audio={len(audio_assets)}", + f'--limit-mm-per-prompt={{"audio":{len(audio_assets)}}}', "--trust-remote-code" ] + [ f"--{key.replace('_','-')}={value}" diff --git a/vllm/config.py b/vllm/config.py index a8bf8a1a11a..8450e79cfa2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2725,6 +2725,7 @@ def verify_with_model_config(self, model_config: ModelConfig): self.prompt_adapter_dtype) +@config @dataclass class MultiModalConfig: """Controls the behavior of multimodal models.""" @@ -2732,6 +2733,8 @@ class MultiModalConfig: limit_per_prompt: Mapping[str, int] = field(default_factory=dict) """ The maximum number of input items allowed per prompt for each modality. + This should be a JSON string that will be parsed into a dictionary. + Defaults to 1 (V0) or 999 (V1) for each modality. """ def compute_hash(self) -> str: @@ -2753,20 +2756,15 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - def get_default_limit_per_prompt(self) -> int: - """ - Return the default number of input items allowed per prompt - for any modality if not specified by the user. - """ - return 999 if envs.VLLM_USE_V1 else 1 - def get_limit_per_prompt(self, modality: str) -> int: """ Get the maximum number of input items allowed per prompt for the given modality. """ - default = self.get_default_limit_per_prompt() - return self.limit_per_prompt.get(modality, default) + return self.limit_per_prompt.get( + modality, + 999 if envs.VLLM_USE_V1 else 1, + ) # TODO: Add configs to init vision tower or not. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c11a4af1951..112d83ce47d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -20,11 +20,12 @@ DecodingConfig, Device, DeviceConfig, DistributedExecutorBackend, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ModelImpl, ObservabilityConfig, - ParallelConfig, PoolerConfig, PoolType, - PromptAdapterConfig, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, TaskOption, TokenizerPoolConfig, - VllmConfig, get_attr_docs, get_field) + ModelConfig, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PoolType, PromptAdapterConfig, SchedulerConfig, + SchedulerPolicy, SpeculativeConfig, TaskOption, + TokenizerPoolConfig, VllmConfig, get_attr_docs, + get_field) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -190,7 +191,8 @@ class EngineArgs: TokenizerPoolConfig.pool_type tokenizer_pool_extra_config: dict[str, Any] = \ get_field(TokenizerPoolConfig, "extra_config") - limit_mm_per_prompt: Optional[Mapping[str, int]] = None + limit_mm_per_prompt: Mapping[str, int] = \ + get_field(MultiModalConfig, "limit_per_prompt") mm_processor_kwargs: Optional[Dict[str, Any]] = None disable_mm_preprocessor_cache: bool = False enable_lora: bool = False @@ -701,18 +703,14 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]: **tokenizer_kwargs["extra_config"]) # Multimodal related configs - parser.add_argument( - '--limit-mm-per-prompt', - type=nullable_kvs, - default=EngineArgs.limit_mm_per_prompt, - # The default value is given in - # MultiModalConfig.get_default_limit_per_prompt - help=('For each multimodal plugin, limit how many ' - 'input instances to allow for each prompt. ' - 'Expects a comma-separated list of items, ' - 'e.g.: `image=16,video=2` allows a maximum of 16 ' - 'images and 2 videos per prompt. 
Defaults to ' - '1 (V0) or 999 (V1) for each modality.')) + multimodal_kwargs = get_kwargs(MultiModalConfig) + multimodal_group = parser.add_argument_group( + title="MultiModalConfig", + description=MultiModalConfig.__doc__, + ) + multimodal_group.add_argument('--limit-mm-per-prompt', + **multimodal_kwargs["limit_per_prompt"]) + parser.add_argument( '--mm-processor-kwargs', default=None, diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 7f289426d34..c1980f61b50 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1117,8 +1117,9 @@ def _to_mm_items( if num_items > allowed_limit: raise ValueError( - f"You set or defaulted to {modality}={allowed_limit} " - f"in --limit-mm-per-prompt`, but passed {num_items} " + "You set or defaulted to " + f'{{"{modality}":{allowed_limit}}} in ' + f"`--limit-mm-per-prompt`, but passed {num_items} " f"{modality} items in the same prompt.") return mm_items diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index def0595013b..b8a1eac5357 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -194,9 +194,9 @@ def map_input( max_items = self._limits_by_model[model_config][data_key] if num_items > max_items: raise ValueError( - f"You set {data_key}={max_items} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but found {num_items} items " - "in the same prompt.") + f'You set {{"{data_key}":{max_items}}} (or defaulted to ' + f"1) in `--limit-mm-per-prompt`, but found {num_items} " + "items in the same prompt.") input_dict = plugin.map_input(model_config, data_value, mm_processor_kwargs) From 8f33dc97a5b213c71b200c4cf84c446fb8ad7999 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Apr 2025 17:57:25 +0200 Subject: [PATCH 4/5] Review comments Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/entrypoints/openai/test_audio.py | 2 +- tests/entrypoints/openai/test_video.py | 2 +- tests/entrypoints/openai/test_vision.py | 2 +- tests/entrypoints/openai/test_vision_embedding.py | 2 +- tests/models/decoder_only/audio_language/test_ultravox.py | 4 ++-- vllm/config.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index e28a3622084..a0a3215d67b 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -27,7 +27,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f'{{"audio":{MAXIMUM_AUDIOS}}}', + str({"audio": MAXIMUM_AUDIOS}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 0a802651f5a..263842b94a7 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -31,7 +31,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f'{{"video":{MAXIMUM_VIDEOS}}}', + str({"video": MAXIMUM_VIDEOS}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index c1894f3af01..4aeb1700ba9 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -35,7 +35,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f'{{"image":{MAXIMUM_IMAGES}}}', + str({"image": 
MAXIMUM_IMAGES}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index a0b19b2a094..b1b24d8029b 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -37,7 +37,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f'{{"image":{MAXIMUM_IMAGES}}}', + str({"image": MAXIMUM_IMAGES}), "--chat-template", str(vlm2vec_jinja_path), ] diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 9f54ed9770d..28e07aecb98 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -47,10 +47,10 @@ def audio(request): pytest.param(CHUNKED_PREFILL_KWARGS), ]) def server(request, audio_assets): + limit_mm_per_prompt = {"audio": len(audio_assets)} args = [ "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", - f'--limit-mm-per-prompt={{"audio":{len(audio_assets)}}}', - "--trust-remote-code" + f"--limit-mm-per-prompt='{limit_mm_per_prompt}'", "--trust-remote-code" ] + [ f"--{key.replace('_','-')}={value}" for key, value in request.param.items() diff --git a/vllm/config.py b/vllm/config.py index 8450e79cfa2..db43d790c53 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2734,7 +2734,7 @@ class MultiModalConfig: """ The maximum number of input items allowed per prompt for each modality. This should be a JSON string that will be parsed into a dictionary. - Defaults to 1 (V0) or 999 (V1) for each modality. + Defaults to 1 (V0) or 999 (V1) for each modality. """ def compute_hash(self) -> str: From 206d69a88680863c2e011daeb49f464bd6d7f0ce Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 17 Apr 2025 18:21:01 +0200 Subject: [PATCH 5/5] Review comments Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/decoder_only/audio_language/test_ultravox.py | 6 +++--- vllm/multimodal/processing.py | 3 ++- vllm/multimodal/registry.py | 7 ++++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 28e07aecb98..3d058d1bca5 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -47,10 +47,10 @@ def audio(request): pytest.param(CHUNKED_PREFILL_KWARGS), ]) def server(request, audio_assets): - limit_mm_per_prompt = {"audio": len(audio_assets)} args = [ - "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", - f"--limit-mm-per-prompt='{limit_mm_per_prompt}'", "--trust-remote-code" + "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", + "--limit-mm-per-prompt", + str({"audio": len(audio_assets)}), "--trust-remote-code" ] + [ f"--{key.replace('_','-')}={value}" for key, value in request.param.items() diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c1980f61b50..16358d1a5ee 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import json import re import sys from abc import ABC, abstractmethod @@ -1118,7 +1119,7 @@ def _to_mm_items( if num_items > allowed_limit: raise ValueError( "You set or defaulted to " - 
f'{{"{modality}":{allowed_limit}}} in ' + f"'{json.dumps({modality: allowed_limit})}' in " f"`--limit-mm-per-prompt`, but passed {num_items} " f"{modality} items in the same prompt.") diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b8a1eac5357..5c687e49d22 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import json from collections import UserDict from collections.abc import Mapping, Sequence from dataclasses import dataclass @@ -194,9 +195,9 @@ def map_input( max_items = self._limits_by_model[model_config][data_key] if num_items > max_items: raise ValueError( - f'You set {{"{data_key}":{max_items}}} (or defaulted to ' - f"1) in `--limit-mm-per-prompt`, but found {num_items} " - "items in the same prompt.") + f"You set '{json.dumps({data_key: max_items})}' (or " + "defaulted to 1) in `--limit-mm-per-prompt`, but found " + f"{num_items} items in the same prompt.") input_dict = plugin.map_input(model_config, data_value, mm_processor_kwargs)