Skip to content

[Bugfix] Fix the issue where llm.generate cannot be called repeatedly after setting GuidedDecodingParams #16767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,17 @@ def backend_options(self) -> list[str]:
return []
return self.backend.split(":")[1].split(",")

def add_option(self, opt_name: str) -> None:
"""Adds an option to the backend options."""
if not self.backend:
self.backend = f":{opt_name}"
elif ":" not in self.backend:
self.backend += f":{opt_name}"
else:
options = set(self.backend_options())
options.add(opt_name)
self.backend = f"{self.backend_name}:{','.join(sorted(options))}"

def no_fallback(self) -> bool:
"""Returns True if the "no-fallback" option is supplied for the guided
decoding backend"""
Expand Down
11 changes: 10 additions & 1 deletion vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,14 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
raise ValueError(f"Only {supported_backends} structured output is "
"supported in V1.")
if params.guided_decoding.backend:
if params.guided_decoding.backend != engine_level_backend:
# Request-level backend selection is not supported in V1.
# The values may differ if `params` is reused and was set
# to a specific backend based on `auto` behavior in a previous
# request. We remember that it was set as a result of `auto`
# using the `_auto` option set on the backend in the params.
if (params.guided_decoding.backend != engine_level_backend
and not (engine_level_backend == "auto" and "_auto"
in params.guided_decoding.backend_options())):
raise ValueError(
"Request-level structured output backend selection is no "
"longer supported. The request specified "
Expand Down Expand Up @@ -182,6 +189,8 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
# The request includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
params.guided_decoding.backend = "guidance"
# Remember that this backend was set automatically
params.guided_decoding.add_option("_auto")

if engine_level_backend.startswith("guidance"):
# TODO ideally we would have the LLTokenizer here as Lark syntax
Expand Down