diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 691ca59b062..f173344344f 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -30,6 +30,7 @@ def create_scheduler(
     use_kv_connector: bool = False,
     num_blocks: int = 10000,
     block_size: int = 16,
+    max_model_len: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.

@@ -44,12 +45,15 @@ def create_scheduler(
     Returns:
       :class:`Scheduler` instance
     '''
+    if max_model_len is None:
+        max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_num_batched_tokens,
+        max_model_len=max_model_len,
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
+        enable_chunked_prefill=True,
     )
     model_config = ModelConfig(
         model=model,
@@ -296,6 +300,7 @@ def test_no_mm_input_chunking():
         model="llava-hf/llava-1.5-7b-hf",
         max_num_batched_tokens=1024,
         disable_chunked_mm_input=True,
+        max_model_len=2048,
     )
     mm_positions = [[PlaceholderRange(offset=400, length=800)]]
     requests = create_requests(num_requests=1,
diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py
new file mode 100644
index 00000000000..f577fb4ab32
--- /dev/null
+++ b/tests/v1/spec_decode/test_max_len.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Test whether spec decoding handles the max model length properly."""
+
+import pytest
+
+from vllm import LLM, SamplingParams
+
+_PROMPTS = [
+    "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
+    "Repeat the following sentence 10 times: Consistency is key to mastering any skill.",  # noqa: E501
+    "Who won the Turing Award in 2018, and for what contribution? Describe in detail.",  # noqa: E501
+]
+
+
+@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
+def test_ngram_max_len(
+    monkeypatch: pytest.MonkeyPatch,
+    num_speculative_tokens: int,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            model="facebook/opt-125m",
+            max_model_len=100,
+            enforce_eager=True,  # For faster initialization.
+            speculative_config={
+                "method": "ngram",
+                "prompt_lookup_max": 5,
+                "prompt_lookup_min": 3,
+                "num_speculative_tokens": num_speculative_tokens,
+            },
+        )
+        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+        llm.generate(_PROMPTS, sampling_params)
+
+
+@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
+def test_eagle_max_len(
+    monkeypatch: pytest.MonkeyPatch,
+    num_speculative_tokens: int,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            enforce_eager=True,  # For faster initialization.
+            speculative_config={
+                "method": "eagle",
+                "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
+                "num_speculative_tokens": num_speculative_tokens,
+            },
+            max_model_len=100,
+        )
+        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+        llm.generate(_PROMPTS, sampling_params)
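Note: the two tests above only check that generation with a small max_model_len completes without crashing. As a rough sketch of the invariant they exercise (not part of the patch; assert_within_max_len is a hypothetical helper name, using the standard RequestOutput fields), the returned outputs could additionally be validated like this:

from vllm.outputs import RequestOutput


def assert_within_max_len(outputs: list[RequestOutput],
                          max_model_len: int) -> None:
    # Prompt tokens plus generated tokens must never exceed the model's
    # maximum length, even when the drafter proposes tokens past the limit.
    for output in outputs:
        num_prompt_tokens = len(output.prompt_token_ids or [])
        num_generated_tokens = len(output.outputs[0].token_ids)
        assert num_prompt_tokens + num_generated_tokens <= max_model_len

For example, test_ngram_max_len could call assert_within_max_len(llm.generate(_PROMPTS, sampling_params), max_model_len=100).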
diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py
index 5caa4f052fc..50548219fff 100644
--- a/tests/v1/spec_decode/test_ngram.py
+++ b/tests/v1/spec_decode/test_ngram.py
@@ -2,7 +2,7 @@

 import numpy as np

-from vllm.config import SpeculativeConfig, VllmConfig
+from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
 from vllm.v1.spec_decode.ngram_proposer import (NgramProposer,
                                                 _find_subarray_kmp,
                                                 _kmp_lps_array)
@@ -42,14 +42,24 @@ def test_find_subarray_kmp():
 def test_ngram_proposer():

     def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
-        return NgramProposer(vllm_config=VllmConfig(
-            speculative_config=SpeculativeConfig.from_dict(
-                {
-                    "prompt_lookup_min": min_n,
-                    "prompt_lookup_max": max_n,
-                    "num_speculative_tokens": k,
-                    "method": "ngram",
-                })))
+        # Dummy model config. Just to set max_model_len.
+        model_config = ModelConfig(model="facebook/opt-125m",
+                                   task="generate",
+                                   max_model_len=100,
+                                   tokenizer="facebook/opt-125m",
+                                   tokenizer_mode="auto",
+                                   dtype="auto",
+                                   seed=None,
+                                   trust_remote_code=False)
+        return NgramProposer(
+            vllm_config=VllmConfig(model_config=model_config,
+                                   speculative_config=SpeculativeConfig.
+                                   from_dict({
+                                       "prompt_lookup_min": min_n,
+                                       "prompt_lookup_max": max_n,
+                                       "num_speculative_tokens": k,
+                                       "method": "ngram",
+                                   })))

     # No match.
     result = ngram_proposer(
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 69e7cc8ee08..16efc42f212 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -185,6 +185,13 @@ def schedule(self) -> SchedulerOutput:

             num_new_tokens = min(num_new_tokens, token_budget)
             assert num_new_tokens > 0
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - request.num_computed_tokens)
+            assert num_new_tokens > 0
+
             # Schedule encoder inputs.
             if request.has_encoder_inputs:
                 (encoder_inputs_to_schedule, num_new_tokens,
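The scheduler hunk above is the core fix: with speculative decoding, a single decode step schedules one verified token plus the speculative tokens, which can push a request's positions past max_model_len. A minimal standalone sketch of the clamping arithmetic (clamp_num_new_tokens and its arguments are illustrative names, not vLLM API):

def clamp_num_new_tokens(num_new_tokens: int, num_computed_tokens: int,
                         token_budget: int, max_model_len: int) -> int:
    # Clamp to the remaining token budget, as the scheduler already does.
    num_new_tokens = min(num_new_tokens, token_budget)
    # Then clamp so computed + new tokens never exceed the max model length.
    num_new_tokens = min(num_new_tokens, max_model_len - num_computed_tokens)
    assert num_new_tokens > 0
    return num_new_tokens


# Example: a request with 97 computed tokens in a 100-token model, asking for
# 1 verified token plus 3 speculative tokens, is trimmed to 3 new tokens.
assert clamp_num_new_tokens(4, 97, 1024, 100) == 3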
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 2322463c071..9505bd7ce43 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -12,6 +12,8 @@
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata

+PADDING_SLOT_ID = -1
+

 class EagleProposer:

@@ -23,6 +25,7 @@ def __init__(
         self.vllm_config = vllm_config
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
+        self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -112,22 +115,48 @@ def propose(
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
             positions += 1
+
+            # NOTE(woosuk): We should handle the case where the draft model
+            # generates tokens beyond the max model length. Since it is complex
+            # to remove such requests from the batch, we keep them in the batch
+            # but adjust the position ids and slot mappings to avoid the
+            # out-of-range access during the model execution. The draft tokens
+            # generated with this adjustment should be ignored.
+            exceeds_max_model_len = positions >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get an out-of-range error in RoPE.
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)
+
+            # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overhead in attention.
+            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+
             # Compute the slot mapping.
-            block_numbers = positions // self.block_size
+            block_numbers = clamped_positions // self.block_size
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          positions % self.block_size)
+                                          clamped_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+                                                    PADDING_SLOT_ID)

             # Run the model.
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
                     hidden_states=hidden_states,
-                    positions=positions,
+                    positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
             draft_token_ids, probs = compute_probs_and_sample_next_token(
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 7e548bb48b5..704153d43a2 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -18,6 +18,9 @@ def __init__(self, vllm_config: VllmConfig):
         # tokens follow the match, we will return the maximum amount of
         # tokens until the end.
         self.k = vllm_config.speculative_config.num_speculative_tokens
+        # Maximum length of the model.
+        self.max_model_len = vllm_config.model_config.max_model_len
+
         # Trigger Numba JIT compilation for N-gram proposer.
         # This usually takes less than 1 second.
         self.propose(np.zeros(1024, dtype=np.int32))
@@ -50,9 +53,14 @@ def propose(
         followed that pattern. Here we will return [4,2,3] because we
         only have three tokens after the match.
         """
+        # Do not generate draft tokens beyond the max model length.
+        k = min(self.k, self.max_model_len - context_token_ids.shape[0])
+        if k <= 0:
+            return None
+
         # TODO(woosuk): Optimize this.
         for n in range(self.max_n, self.min_n - 1, -1):
-            result = _find_subarray_kmp(context_token_ids, n, self.k)
+            result = _find_subarray_kmp(context_token_ids, n, k)
             if result is not None:
                 return result
         return None
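The NgramProposer change above caps the draft length so the proposer never drafts past the context window. A small numeric illustration (the values below are made up):

import numpy as np

max_model_len = 100
num_speculative_tokens = 10  # Corresponds to self.k in the proposer.

# With 98 context tokens, only 2 draft tokens still fit under max_model_len.
context_token_ids = np.zeros(98, dtype=np.int32)
k = min(num_speculative_tokens, max_model_len - context_token_ids.shape[0])
assert k == 2

# With a full context, k <= 0 and propose() returns None: no draft at all.
full_context_ids = np.zeros(100, dtype=np.int32)
assert min(num_speculative_tokens,
           max_model_len - full_context_ids.shape[0]) <= 0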
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7c88ecc31d0..4cb5a8e171a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1271,7 +1271,8 @@ def generate_draft_token_ids(
                 draft_token_ids.append([])
                 continue

-            # Skip requests that require top-p, top-k, etc.
+            # Skip requests that require sampling parameters that are not
+            # supported with speculative decoding.
             req_id = self.input_batch.req_ids[i]
             if not is_spec_decode_supported(req_id, self.input_batch):
                 draft_token_ids.append([])
@@ -1280,6 +1281,11 @@
             # Add sampled_token_ids to token_ids_cpu.
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
+            if end_idx >= self.max_model_len:
+                # Skip requests that have already reached the max model length.
+                draft_token_ids.append([])
+                continue
+
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx])
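For reference, the eagle.py hunk earlier in this diff keeps over-length requests in the draft batch but neutralizes them. A self-contained sketch of that masking (the tensor values, block_size, and max_model_len below are made up for illustration):

import torch

PADDING_SLOT_ID = -1
max_model_len = 8
block_size = 4
positions = torch.tensor([3, 7, 9])          # The third request is past the limit.
block_table = torch.tensor([[0, 1, 2]] * 3)  # One row of block ids per request.

exceeds_max_model_len = positions >= max_model_len
# Clamp positions to 0 so RoPE never sees an out-of-range position.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

# Recompute the slot mapping from the clamped positions, then redirect the
# over-length requests to the padding slot so the KV cache is never written
# at a bogus location.
block_numbers = clamped_positions // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

assert clamped_positions.tolist() == [3, 7, 0]
assert slot_mapping.tolist() == [3, 7, PADDING_SLOT_ID]

The draft tokens produced for the masked requests are then simply discarded, which matches the NOTE in the eagle.py hunk.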