
Commit dcaa929

zixi-qi authored and dbyoung18 committed

[Bugfix] Fix v1/spec_decode/test_ngram.py (vllm-project#16895)

Signed-off-by: qizixi <[email protected]>

1 parent c73af87 commit dcaa929

File tree: 2 files changed, +30 −38 lines changed


tests/v1/spec_decode/test_ngram.py (22 additions, 31 deletions)
```diff
@@ -2,6 +2,7 @@
 
 import numpy as np
 
+from vllm.config import SpeculativeConfig, VllmConfig
 from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
                                                 _find_subarray_kmp,
                                                 _kmp_lps_array)
@@ -39,50 +40,40 @@ def test_find_subarray_kmp():
 
 
 def test_ngram_proposer():
-    proposer = NgramProposer()
+
+    def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
+        return NgramProposer(vllm_config=VllmConfig(
+            speculative_config=SpeculativeConfig.from_dict(
+                {
+                    "prompt_lookup_min": min_n,
+                    "prompt_lookup_max": max_n,
+                    "num_speculative_tokens": k,
+                    "method": "ngram",
+                })))
 
     # No match.
-    result = proposer.propose(
-        context_token_ids=np.array([1, 2, 3, 4, 5]),
-        min_n=2,
-        max_n=2,
-        k=2,
-    )
+    result = ngram_proposer(
+        2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5]))
     assert result is None
 
     # No match for 4-gram.
-    result = proposer.propose(
-        context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
-        min_n=4,
-        max_n=4,
-        k=2,
-    )
+    result = ngram_proposer(
+        4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert result is None
 
     # No match for 4-gram but match for 3-gram.
-    result = proposer.propose(
-        context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]),
-        min_n=3,
-        max_n=4,
-        k=2,
-    )
+    result = ngram_proposer(
+        3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]))
     assert np.array_equal(result, np.array([4, 1]))
 
     # Match for both 4-gram and 3-gram.
     # In this case, the proposer should return the 4-gram match.
-    result = proposer.propose(
-        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]),
-        min_n=3,
-        max_n=4,
-        k=2,
-    )
+    result = ngram_proposer(3, 4, 2).propose(
+        context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 1]
 
     # Match for 2-gram and 3-gram, but not 4-gram.
-    result = proposer.propose(
-        context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]),
-        min_n=2,
-        max_n=4,
-        k=2,
-    )
+    result = ngram_proposer(
+        2, 4,
+        2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]))
     assert np.array_equal(result, np.array([1, 2]))  # Not [5, 2]
```
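
For readers skimming the diff, the behavior these assertions pin down is plain prompt lookup: take the last n tokens of the context (trying max_n down to min_n), look for an earlier occurrence of that n-gram, and propose the k tokens that followed it. Below is a naive quadratic sketch of that idea for illustration only; the real proposer uses the KMP-based `_find_subarray_kmp` imported above, and `propose_ngram` here is a hypothetical helper, not vLLM API.

```python
from typing import Optional

import numpy as np


def propose_ngram(context: np.ndarray, min_n: int, max_n: int,
                  k: int) -> Optional[np.ndarray]:
    """Naive prompt lookup: match the suffix n-gram, longest n first."""
    for n in range(max_n, min_n - 1, -1):
        if len(context) <= n:
            continue
        suffix = context[-n:]
        # Scan earlier windows for the first occurrence of the suffix.
        for start in range(len(context) - n):
            if np.array_equal(context[start:start + n], suffix):
                draft = context[start + n:start + n + k]
                if len(draft) == k:
                    return draft
    return None


# Mirrors the third test case above: the 3-gram [1, 2, 3] recurs, so the
# two tokens that followed its earlier occurrence ([4, 1]) become the draft.
assert np.array_equal(
    propose_ngram(np.array([1, 2, 3, 4, 1, 2, 3]), 3, 4, 2), np.array([4, 1]))
```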

vllm/config.py (8 additions, 7 deletions)
```diff
@@ -120,7 +120,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
     def pairwise(iterable):
         """
         Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
-
+
         Can be removed when Python 3.9 support is dropped.
         """
         iterator = iter(iterable)
@@ -266,7 +266,7 @@ class ModelConfig:
         config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
         hf_token: The token to use as HTTP bearer authorization for remote files
-            . If `True`, will use the token generated when running
+            . If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
@@ -1624,7 +1624,7 @@ class ParallelConfig:
     """The full name of the worker class to use. If "auto", the worker class
     will be determined based on the platform."""
     sd_worker_cls: str = "auto"
-    """The full name of the worker class to use for speculative decofing.
+    """The full name of the worker class to use for speculative decofing.
     If "auto", the worker class will be determined based on the platform."""
     worker_extension_cls: str = ""
     """The full name of the worker extension class to use. The worker extension
@@ -1815,13 +1815,13 @@ class SchedulerConfig:
 
     max_num_batched_tokens: int = None  # type: ignore
     """Maximum number of tokens to be processed in a single iteration.
-
+
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
 
     max_num_seqs: int = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.
-
+
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
 
@@ -1867,7 +1867,7 @@ class SchedulerConfig:
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = field(init=False)
     """Multimodal encoder compute budget, only used in V1.
-
+
     NOTE: This is not currently configurable. It will be overridden by
     max_num_batched_tokens in case max multimodal embedding size is larger."""
 
@@ -2306,7 +2306,8 @@ def __post_init__(self):
         if self.model is None and self.num_speculative_tokens is not None:
             # TODO(Shangming): Refactor mtp configuration logic when supporting
             # mtp acceleration for more models besides deepseek_v3
-            if self.target_model_config.hf_text_config.model_type \
+            if self.target_model_config and \
+                    self.target_model_config.hf_text_config.model_type \
                     == "deepseek_v3":
                 # use the draft model from the same model:
                 self.model = self.target_model_config.model
```
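
The only functional change in this file is the added `self.target_model_config and` guard in `__post_init__` (the other hunks are whitespace cleanups): a `SpeculativeConfig` built standalone, as the refactored test now does with `method="ngram"` and no target model, leaves `target_model_config` as `None`, so the old code raised `AttributeError` on `.hf_text_config`. Below is a minimal sketch of the failure mode the guard prevents; `resolve_draft` is a hypothetical distillation of the branch above, not vLLM code.

```python
def resolve_draft(target_model_config, model=None):
    # Without the `target_model_config and` guard, accessing
    # .hf_text_config on None raises AttributeError here.
    if (target_model_config
            and target_model_config.hf_text_config.model_type
            == "deepseek_v3"):
        model = target_model_config.model
    return model


# Standalone speculative config with no target model: the guard makes
# this a clean no-op instead of a crash.
assert resolve_draft(None) is None
```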
