Improve configs - SpeculativeConfig #16971

Merged: 3 commits, Apr 22, 2025
172 changes: 73 additions & 99 deletions vllm/config.py
@@ -2125,139 +2125,113 @@ def __post_init__(self):
self.device = torch.device(self.device_type)


SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]


@config
@dataclass
class SpeculativeConfig:
"""
Configuration for speculative decoding.
Configurable parameters include:
- General Speculative Decoding Control:
- num_speculative_tokens (int): The number of speculative
tokens, if provided. It will default to the number in the draft
model config if present, otherwise, it is required.
- model (Optional[str]): The name of the draft model, eagle head,
or additional weights, if provided.
- method (Optional[str]): The name of the speculative method to use.
If users provide and set the `model` param, the speculative method
type will be detected automatically if possible, if `model` param
is not provided, the method name must be provided.
- Possible values:
- ngram
Related additional configuration:
- prompt_lookup_max (Optional[int]):
Maximum size of ngram token window when using Ngram
proposer, required when method is set to ngram.
- prompt_lookup_min (Optional[int]):
Minimum size of ngram token window when using Ngram
proposer, if provided. Defaults to 1.
- eagle
- medusa
- mlp_speculator
- draft_model
- acceptance_method (str): The method to use for accepting draft
tokens. This can take two possible values: 'rejection_sampler' and
'typical_acceptance_sampler' for RejectionSampler and
TypicalAcceptanceSampler respectively. If not specified, it
defaults to 'rejection_sampler'.
- Possible values:
- rejection_sampler
- typical_acceptance_sampler
Related additional configuration:
- posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the
posterior probability of a token in the target model
for it to be accepted. This threshold is used only
when we use the TypicalAcceptanceSampler for token
acceptance.
- posterior_alpha (Optional[float]):
Scaling factor for entropy-based threshold, applied
when using TypicalAcceptanceSampler.
- draft_tensor_parallel_size (Optional[int]): The degree of the tensor
parallelism for the draft model. Can only be 1 or the same as the
target model's tensor parallel size.
- disable_logprobs (bool): If set to True, token log probabilities are
not returned during speculative decoding. If set to False, token
log probabilities are returned according to the log probability
settings in SamplingParams. If not specified, it defaults to True.

- Draft Model Configuration:
- quantization (Optional[str]): Quantization method that was used to
quantize the draft model weights. If None, we assume the
model weights are not quantized. Note that it only takes effect
when using the draft model-based speculative method.
- max_model_len (Optional[int]): The maximum model length of the
draft model. Used when testing the ability to skip
speculation for some sequences.
- revision: The specific model version to use for the draft model. It
can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version.
- code_revision: The specific revision to use for the draft model code
on Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
"""Configuration for speculative decoding."""

- Advanced Control:
- disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
batch expansion for scoring proposals. If not specified, it
defaults to False.
- disable_by_batch_size (Optional[int]): Disable speculative decoding
for new incoming requests when the number of enqueued requests is
larger than this value, if provided.

Although the parameters above are structured hierarchically, there is no
need to nest them during configuration.

Non-configurable internal parameters include:
- Model Configuration:
- target_model_config (ModelConfig): The configuration of the target
model.
- draft_model_config (ModelConfig): The configuration of the draft
model initialized internal.
- Parallelism Configuration:
- target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
- draft_parallel_config (ParallelConfig): The parallel configuration
for the draft model initialized internal.
- Execution Control:
- enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with speculative decode.
- disable_log_stats (bool): Whether to disable the periodic printing of
stage times in speculative decoding.
"""
# speculative configs from cli args
# General speculative decoding control
num_speculative_tokens: int = field(default=None,
init=True) # type: ignore
method: Optional[str] = None
acceptance_method: str = "rejection_sampler"
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: Optional[str] = None
"""The name of the draft model, eagle head, or additional weights, if
provided."""
method: Optional[SpeculativeMethod] = None
"""The name of the speculative method to use. If users provide and set the
`model` param, the speculative method type will be detected automatically
if possible, if `model` param is not provided, the method name must be
provided.

If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
"""The method to use for accepting draft tokens:\n
- "rejection_sampler" maps to `RejectionSampler`.\n
- "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.

If using `typical_acceptance_sampler`, the related options
`posterior_threshold` and `posterior_alpha` should be considered."""
draft_tensor_parallel_size: Optional[int] = None
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
"""If set to True, token log probabilities are not returned during
speculative decoding. If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams."""

model: Optional[str] = None
# Draft model configuration
quantization: Optional[str] = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
max_model_len: Optional[int] = None
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: Optional[str] = None
"""The specific model version to use for the draft model. It can be a
branch name, a tag name, or a commit id. If unspecified, will use the
default version."""
code_revision: Optional[str] = None
"""The specific revision to use for the draft model code on Hugging Face
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version."""

# Advanced control
disable_mqa_scorer: bool = False
"""Disable the MQA scorer and fall back to batch expansion for scoring
proposals."""
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""

# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
prompt_lookup_min: Optional[int] = None
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""

# Typical acceptance sampler configuration
posterior_threshold: Optional[float] = None
"""A threshold value that sets a lower bound on the posterior probability
of a token in the target model for it to be accepted. This threshold is
used only when we use the `TypicalAcceptanceSampler` for token acceptance.
"""
posterior_alpha: Optional[float] = None
"""Scaling factor for entropy-based threshold, applied when using
`TypicalAcceptanceSampler`."""

# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
"""The configuration of the target model."""
target_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
"""The parallel configuration for the target model."""
enable_chunked_prefill: bool = field(default=None,
init=True) # type: ignore
"""Whether vLLM is configured to use chunked prefill or not. Used for
raising an error since it's not yet compatible with speculative decode."""
disable_log_stats: bool = field(default=None, init=True) # type: ignore
"""Whether to disable the periodic printing of stage times in speculative
decoding."""

# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
"""The configuration of the draft model initialized internal."""
draft_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
"""The parallel configuration for the draft model initialized internal."""

def compute_hash(self) -> str:
"""
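To illustrate the layout this PR moves to, here is a minimal, self-contained sketch of a dataclass that uses a `Literal` alias and per-field docstrings, the same shape as the rewritten `SpeculativeConfig`. The class and the reduced field set below (`MiniSpeculativeConfig`) are invented for illustration and are not vLLM's actual code:

from dataclasses import dataclass, fields
from typing import Literal, Optional, get_args

SpecMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
                     "draft_model"]

@dataclass
class MiniSpeculativeConfig:
    num_speculative_tokens: Optional[int] = None
    """Number of speculative tokens to propose per step."""
    method: Optional[SpecMethod] = None
    """Speculative method; detected from the draft model when possible."""
    prompt_lookup_max: Optional[int] = None
    """Maximum ngram window; required when method is "ngram"."""

# A Literal alias lets tooling enumerate the valid choices, e.g. for CLI help.
print(get_args(SpecMethod))
# ('ngram', 'eagle', 'medusa', 'mlp_speculator', 'draft_model')
print([f.name for f in fields(MiniSpeculativeConfig)])
# ['num_speculative_tokens', 'method', 'prompt_lookup_max']

The per-field docstrings are plain string literals and are not attached to the fields at runtime, but documentation and CLI-help generators can harvest them from the source, which is presumably why the single large class docstring is traded for this layout.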
17 changes: 12 additions & 5 deletions vllm/engine/arg_utils.py
@@ -766,11 +766,18 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
help=('Maximum number of forward steps per '
'scheduler call.'))

parser.add_argument('--speculative-config',
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
# Speculative arguments
speculative_group = parser.add_argument_group(
title="SpeculativeConfig",
description=SpeculativeConfig.__doc__,
)
speculative_group.add_argument(
'--speculative-config',
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')

parser.add_argument(
'--ignore-patterns',
action="append",
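As a usage sketch: the new `--speculative-config` flag is parsed with `json.loads`, so its value is a JSON object whose keys mirror the `SpeculativeConfig` fields documented in the config.py diff above. The specific values below are illustrative, not recommendations:

import json

# Illustrative ngram settings; keys come from the SpeculativeConfig fields.
speculative_config = {
    "method": "ngram",
    "num_speculative_tokens": 5,
    "prompt_lookup_max": 4,
    "prompt_lookup_min": 2,
}

# On the command line this would be passed as a single JSON string, e.g.
#   --speculative-config '{"method": "ngram", "num_speculative_tokens": 5, ...}'
print(json.dumps(speculative_config))

Registering the flag under its own argument group with `description=SpeculativeConfig.__doc__` also means the class docstring appears as the group description in `--help` output.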