From 131e6797f25cd33150366675a1c2e72f6fbc9d58 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Wed, 9 Apr 2025 23:20:59 +0000 Subject: [PATCH 01/15] add request level, per-step acceptance counts tracking for spec dec Signed-off-by: Bryan Lu --- vllm/outputs.py | 4 ++++ vllm/v1/core/sched/scheduler.py | 8 +++++++- vllm/v1/engine/__init__.py | 2 ++ vllm/v1/engine/llm_engine.py | 19 ++++++++++++++----- vllm/v1/engine/output_processor.py | 26 ++++++++++++++++---------- vllm/v1/engine/processor.py | 3 ++- vllm/v1/request.py | 4 +++- 7 files changed, 48 insertions(+), 18 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d882..661dfe2870a 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -43,6 +43,7 @@ class CompletionOutput: finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None lora_request: Optional[LoRARequest] = None + spec_token_acceptance_counts: Optional[list[int]] = None def finished(self) -> bool: return self.finish_reason is not None @@ -133,6 +134,9 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens + self.spec_token_acceptance_counts = [ + o.spec_token_acceptance_counts for o in outputs + ] def add(self, next_output: "RequestOutput") -> None: """Merge subsequent RequestOutput into this one""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 488d32cb82c..ab4b88eed22 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -601,6 +601,10 @@ def update_from_output( num_draft_tokens=len(scheduled_spec_token_ids), num_accepted_tokens=len(generated_token_ids) - 1) + for i in range(len(generated_token_ids)): + if request.spec_token_acceptance_counts is not None: + request.spec_token_acceptance_counts[i] += 1 + cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) # OPTIMIZATION: Avoid list(set) if the set is empty. 
@@ -662,7 +666,9 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + spec_token_acceptance_counts=request. + spec_token_acceptance_counts)) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 1264e43c79d..fa7a66951ea 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -60,6 +60,7 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] + spec_token_acceptance_counts: Optional[list[int]] class EngineCoreEventType(enum.IntEnum): @@ -102,6 +103,7 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + spec_token_acceptance_counts: Optional[list[int]] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4c67186f704..345ea9036eb 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -183,11 +183,20 @@ def add_request( priority: int = 0, ) -> None: # Process raw inputs into the request. 
- request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + num_spec_tokens = 0 + if self.vllm_config.speculative_config is not None: + num_spec_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + trace_headers, + prompt_adapter_request, + priority, + num_spec_tokens=num_spec_tokens) n = params.n if isinstance(params, SamplingParams) else 1 diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 70f072d3c93..b040f42f8e7 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -136,10 +136,9 @@ def from_new_request( ) def make_request_output( - self, - new_token_ids: list[int], - finish_reason: Optional[FinishReason], + self, new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + spec_token_acceptance_counts: Optional[list[int]] ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -150,7 +149,10 @@ def make_request_output( return None completion_output = self._new_completion_output( - new_token_ids, finish_reason, stop_reason) + new_token_ids, + finish_reason, + stop_reason, + spec_token_acceptance_counts=spec_token_acceptance_counts) request_id = self.request_id if self.parent_req is None: @@ -186,10 +188,9 @@ def _new_request_output( ) def _new_completion_output( - self, - token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], + self, token_ids: list[int], finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + spec_token_acceptance_counts: Optional[list[int]] ) -> CompletionOutput: finished = finish_reason is not None @@ -212,7 +213,8 @@ def _new_completion_output( logprobs=logprobs, 
cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, - stop_reason=stop_reason if finished else None) + stop_reason=stop_reason if finished else None, + spec_token_acceptance_counts=spec_token_acceptance_counts) class OutputProcessor: @@ -337,7 +339,11 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, + finish_reason, + stop_reason, + spec_token_acceptance_counts=engine_core_output. + spec_token_acceptance_counts): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7d1913ecebe..468401dfefe 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -177,6 +177,7 @@ def process_inputs( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + num_spec_tokens: int = 0, ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. 
@@ -292,7 +293,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, - ) + spec_token_acceptance_counts=[0] * (num_spec_tokens + 1)) def _validate_model_inputs(self, inputs: ProcessorInputs, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde..1eed5427edd 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -30,6 +30,7 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, + spec_token_acceptance_counts: Optional[list[int]] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -53,6 +54,7 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 + self.spec_token_acceptance_counts = spec_token_acceptance_counts # Multi-modal related self.mm_positions = multi_modal_placeholders or [] @@ -92,7 +94,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), - ) + spec_token_acceptance_counts=request.spec_token_acceptance_counts) def append_output_token_ids( self, From d21afbf21871f67a4e228f92bedc4c5ee26593df Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Sat, 12 Apr 2025 00:35:18 +0000 Subject: [PATCH 02/15] rebase Signed-off-by: Bryan Lu --- examples/offline_inference/eagle.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 453ae7b6f56..bb9993448aa 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -45,8 +45,12 @@ def main(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) 
parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--use_v1", type=str, default="1", help='1 or 0') args = parser.parse_args() + # TODO: remove this option once EAGLE in v1 is ready. + os.environ["VLLM_USE_V1"] = args.use_v1 + model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" @@ -94,10 +98,16 @@ def main(): # to account for the token from the target model that's always going to be # accepted acceptance_counts = [0] * (args.num_spec_tokens + 1) - for output in outputs: - for step, count in enumerate( - output.metrics.spec_token_acceptance_counts): - acceptance_counts[step] += count + if args.use_v1 == '1': + for output in outputs: + for step, count in enumerate( + output.spec_token_acceptance_counts[0]): + acceptance_counts[step] += count + else: + for output in outputs: + for step, count in enumerate( + output.metrics.spec_token_acceptance_counts): + acceptance_counts[step] += count print("-" * 50) print(f"mean acceptance length: \ From ddc1afd83b929b207d6636fccd0ac9c8c59ea2d4 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 14 Apr 2025 06:39:03 +0000 Subject: [PATCH 03/15] update design Signed-off-by: Bryan Lu --- vllm/v1/core/sched/scheduler.py | 15 ++++++--------- vllm/v1/engine/__init__.py | 3 +-- vllm/v1/engine/llm_engine.py | 4 ++-- vllm/v1/engine/output_processor.py | 21 +++++++++++++++++++-- vllm/v1/engine/processor.py | 2 +- vllm/v1/request.py | 4 +--- vllm/v1/spec_decode/metrics.py | 11 ++++++++--- 7 files changed, 38 insertions(+), 22 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5cd7a6980bb..ccddd341743 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -608,11 +608,8 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) - - for i in 
range(len(generated_token_ids)): - if request.spec_token_acceptance_counts is not None: - request.spec_token_acceptance_counts[i] += 1 + num_accepted_tokens=len(generated_token_ids) - 1, + request_id=req_id) cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -675,9 +672,7 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events(), - spec_token_acceptance_counts=request. - spec_token_acceptance_counts)) + events=request.take_events())) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -775,11 +770,13 @@ def make_spec_decoding_stats( spec_decoding_stats: Optional[SpecDecodingStats], num_draft_tokens: int, num_accepted_tokens: int, + request_id: str, ) -> Optional[SpecDecodingStats]: if not self.log_stats: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats() spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_accepted_tokens=num_accepted_tokens, + request_id=request_id) return spec_decoding_stats diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index fa7a66951ea..33a6225009b 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -60,7 +60,7 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] - spec_token_acceptance_counts: Optional[list[int]] + num_spec_tokens: int class EngineCoreEventType(enum.IntEnum): @@ -103,7 +103,6 @@ class EngineCoreOutput( finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None - spec_token_acceptance_counts: Optional[list[int]] = None @property def finished(self) -> bool: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 345ea9036eb..f3285735e57 
100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -92,7 +92,7 @@ def __init__( asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, - log_stats=False, # FIXME: implement + log_stats=True, # FIXME: implement ) if not multiprocess_mode: @@ -232,7 +232,7 @@ def step(self) -> list[RequestOutput]: # 2) Process EngineCoreOutputs. processed_outputs = self.output_processor.process_outputs( - outputs.outputs) + outputs.outputs, scheduler_stats=outputs.scheduler_stats) # 3) Abort any reqs that finished due to stop strings. self.engine_core.abort_requests(processed_outputs.reqs_to_abort) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index b040f42f8e7..c5554bd20f0 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -14,7 +14,7 @@ from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, - RequestStateStats) + RequestStateStats, SchedulerStats) class RequestOutputCollector: @@ -81,6 +81,7 @@ def __init__( arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, + num_spec_tokens: int = 0, ): self.request_id = request_id self.parent_req = parent_req @@ -99,6 +100,8 @@ def __init__( self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None + self.spec_token_acceptance_counts = [0] * (num_spec_tokens + 1) + @classmethod def from_new_request( cls, @@ -133,6 +136,7 @@ def from_new_request( arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, + num_spec_tokens=request.num_spec_tokens, ) def make_request_output( @@ -282,6 +286,7 @@ def process_outputs( engine_core_outputs: list[EngineCoreOutput], engine_core_timestamp: Optional[float] = None, iteration_stats: Optional[IterationStats] = None, + scheduler_stats: Optional[SchedulerStats] = None, ) -> OutputProcessorOutput: 
""" Process the EngineCoreOutputs: @@ -320,6 +325,8 @@ def process_outputs( self._update_stats_from_output(req_state, engine_core_output, engine_core_timestamp, iteration_stats) + self._update_stats_from_scheduler(req_id, req_state, + scheduler_stats) new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason @@ -342,7 +349,7 @@ def process_outputs( new_token_ids, finish_reason, stop_reason, - spec_token_acceptance_counts=engine_core_output. + spec_token_acceptance_counts=req_state. spec_token_acceptance_counts): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). @@ -409,3 +416,13 @@ def _update_stats_from_finished(self, req_state: RequestState, ParentRequest.observe_finished_request( req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens) + + def _update_stats_from_scheduler( + self, req_id: str, req_state: RequestState, + scheduler_stats: Optional[SchedulerStats]): + if scheduler_stats is not None and \ + scheduler_stats.spec_decoding_stats is not None: + num_accepted_tokens = scheduler_stats. 
\ + spec_decoding_stats.per_request_stats.get(req_id, 0) + for i in range(num_accepted_tokens): + req_state.spec_token_acceptance_counts[i] += 1 diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a362de74ed0..6488ca6ed4d 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -314,7 +314,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, - spec_token_acceptance_counts=[0] * (num_spec_tokens + 1)) + num_spec_tokens=num_spec_tokens) def _validate_model_inputs(self, inputs: ProcessorInputs, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 1eed5427edd..6be72431dde 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -30,7 +30,6 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, - spec_token_acceptance_counts: Optional[list[int]] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -54,7 +53,6 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 - self.spec_token_acceptance_counts = spec_token_acceptance_counts # Multi-modal related self.mm_positions = multi_modal_placeholders or [] @@ -94,7 +92,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), - spec_token_acceptance_counts=request.spec_token_acceptance_counts) + ) def append_output_token_ids( self, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 7bb3c209d1d..a44523c945c 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass +from dataclasses import dataclass, field 
import numpy as np @@ -13,20 +13,25 @@ class SpecDecodingStats: num_draft_tokens: int = 0 num_accepted_tokens: int = 0 + per_request_stats: dict = field(default_factory=dict) def take(self): copied = SpecDecodingStats(self.num_draft_tokens, - self.num_accepted_tokens) + self.num_accepted_tokens, + self.per_request_stats) self.reset() return copied def reset(self): self.num_draft_tokens = 0 self.num_accepted_tokens = 0 + self.per_request_stats = {} - def observe(self, num_draft_tokens: int, num_accepted_tokens: int): + def observe(self, num_draft_tokens: int, num_accepted_tokens: int, + request_id: str): self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens + self.per_request_stats[request_id] = num_accepted_tokens + 1 class SpecDecodingMetrics: From cbb96bca30ccf93b60bca749326e037011435891 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 14 Apr 2025 06:50:33 +0000 Subject: [PATCH 04/15] minor Signed-off-by: Bryan Lu --- vllm/outputs.py | 4 ++++ vllm/v1/engine/output_processor.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 661dfe2870a..19d8fe08eb6 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -33,6 +33,10 @@ class CompletionOutput: to stop, None if the completion finished for some other reason including encountering the EOS token. lora_request: The LoRA request that was used to generate the output. + spec_token_acceptance_counts: A list tracking the total number of + accepted tokens at each speculation step for a request. Its length + is num_spec_tokens + 1 since there is always one token generated + by the target model. 
""" index: int diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index c5554bd20f0..4d7b86ec951 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -192,9 +192,11 @@ def _new_request_output( ) def _new_completion_output( - self, token_ids: list[int], finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], - spec_token_acceptance_counts: Optional[list[int]] + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + spec_token_acceptance_counts: Optional[list[int]], ) -> CompletionOutput: finished = finish_reason is not None From 55c5bea25a6b974296f05177bc83331746bcec52 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 17 Apr 2025 09:41:31 -0400 Subject: [PATCH 05/15] initial eagle3 attempt (broken) Signed-off-by: Benjamin Chislett --- examples/offline_inference/eagle.py | 10 +- tests/models/registry.py | 6 +- vllm/config.py | 2 +- vllm/model_executor/models/llama.py | 19 +++- vllm/model_executor/models/llama_eagle.py | 110 +++++++++++++++++----- vllm/v1/spec_decode/eagle.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 14 ++- 7 files changed, 131 insertions(+), 33 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index bb9993448aa..ec7a58c68bb 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -51,8 +51,8 @@ def main(): # TODO: remove this option once EAGLE in v1 is ready. 
os.environ["VLLM_USE_V1"] = args.use_v1 - model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" - eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B" + model_dir = "meta-llama/Llama-3.1-8B-Instruct" + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" max_model_len = 2048 @@ -114,6 +114,12 @@ def main(): {sum(acceptance_counts) / acceptance_counts[0]:.2f}") print("-" * 50) + # print acceptance at each token position + for i in range(len(acceptance_counts)): + print( + f"acceptance at token {i}: {acceptance_counts[i] / (acceptance_counts[0]):.2f}" + ) + if __name__ == "__main__": main() diff --git a/tests/models/registry.py b/tests/models/registry.py index 896b6c3bf47..8a82aea3337 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -377,7 +377,11 @@ def check_available_online( "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 + "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="meta-llama/Llama-3.1-8B-Instruct"), # noqa: E501 } _TRANSFORMERS_MODELS = { diff --git a/vllm/config.py b/vllm/config.py index 08947e39bc4..203ab0c300c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2333,7 +2333,7 @@ def __post_init__(self): ) # Automatically detect the method - if "eagle-" in self.draft_model_config.model.lower(): + if "eagle" in self.draft_model_config.model.lower(): self.method = "eagle" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index caa4a5108a9..4d848a97568 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -113,6 +113,10 
@@ def __init__(self, super().__init__() layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size + self.input_hidden_size = hidden_size + if hasattr(config, "input_hidden_size") and \ + config.input_hidden_size is not None: + self.input_hidden_size = config.input_hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -140,7 +144,7 @@ def __init__(self, self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size=hidden_size, + hidden_size=self.input_hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, @@ -219,6 +223,10 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size + # self.input_hidden_size = config.hidden_size + # if hasattr(config, "input_hidden_size") and \ + # config.input_hidden_size is not None: + # self.input_hidden_size = config.input_hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( @@ -356,8 +364,13 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer:self.end_layer]: + extra_hidden_states = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): hidden_states, residual = layer(positions, hidden_states, residual) + if idx == 2 or idx == len(self.layers) // 2 or idx == len( + self.layers) - 3: + extra_hidden_states.append(hidden_states + residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -366,7 +379,7 @@ def forward( }) hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + return hidden_states, extra_hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git 
a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 28ad6128c4f..5db7eace356 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch import torch.nn as nn @@ -8,12 +8,14 @@ from vllm.config import ModelConfig from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) +from vllm.v1.sample.metadata import SamplingMetadata from .utils import AutoWeightsLoader, maybe_prefix @@ -54,16 +56,40 @@ def __init__( self.config.hidden_size, prefix=maybe_prefix(prefix, "embed_tokens"), ) + self.config.input_hidden_size = 2 * self.config.hidden_size self.layers = nn.ModuleList([ LlamaDecoderLayer( self.config, - i == 0, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - ) for i in range(self.config.num_hidden_layers) + disable_input_layernorm=True, + prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + ) ]) - self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.fc = torch.nn.Linear(self.config.hidden_size * 3, self.config.hidden_size, bias=False) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + + # d2t=torch.zeros((self.config.draft_vocab_size),dtype=torch.long) + # t2d=torch.zeros((self.config.vocab_size),dtype=torch.bool) + # self.register_buffer("d2t", d2t) + # self.register_buffer("t2d", t2d) + + # self.t2d = nn.Parameter( + # 
torch.zeros((self.config.vocab_size), dtype=torch.bool), + # requires_grad=False, + # ) + + self.input_layernorm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + self.hidden_norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) def forward( self, @@ -72,17 +98,19 @@ def forward( hidden_states: torch.Tensor, ) -> torch.Tensor: input_embeds = self.embed_tokens(input_ids) - hidden_states = self.fc( - torch.cat((input_embeds, hidden_states), dim=-1)) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - residual, - ) - return hidden_states + residual + input_embeds = self.input_layernorm(input_embeds) + hidden_states = self.hidden_norm(hidden_states) + if (hidden_states.shape != input_embeds.shape): + hidden_states = self.fc(hidden_states) + hidden_states = torch.cat((input_embeds, hidden_states), dim=-1) + + hidden_states, residual = self.layers[0]( + positions, + hidden_states, + None, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -97,6 +125,12 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: + if 'midlayer.input_layernorm' in name: + name = name.replace('midlayer.', '') + if 'midlayer.hidden_norm' in name: + name = name.replace('midlayer.', '') + if 'midlayer.' 
in name: + name = name.replace('midlayer.', 'layers.0.') for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -124,8 +158,19 @@ def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): prefix="model") logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, + self.lm_head = ParallelLMHead( + self.config.draft_vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.draft_vocab_size, + padding_size=(DEFAULT_VOCAB_PADDING_SIZE), + prefix="") + self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, scale=logit_scale) + self.draft_id_to_target_id = nn.Parameter( + torch.zeros((self.config.draft_vocab_size), + dtype=torch.long).type(torch.LongTensor), + requires_grad=False, + ) def forward( self, @@ -135,16 +180,39 @@ def forward( ) -> torch.Tensor: return self.model(input_ids, positions, hidden_states) + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + # pad logits from draft vocab to target vocab + # and convert indices accordingly + base = torch.arange(self.config.draft_vocab_size, device=logits.device) + targets = base + self.draft_id_to_target_id + logits_new = logits.new_full(( + logits.shape[0], + self.config.vocab_size, + ), float('-inf')) + logits_new[:, targets] = logits + return logits_new + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=None, + # skip_prefixes=(["lm_head."] + # if self.config.tie_word_embeddings else None), ) model_weights = {} for name, loaded_weight in weights: - if "lm_head" not in name: + if "t2d" in name: + continue + if "d2t" in name: + name = name.replace("d2t", 
"draft_id_to_target_id") + elif "lm_head" not in name: name = "model." + name model_weights[name] = loaded_weight diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2322463c071..e0aeeeba117 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -206,7 +206,8 @@ def load_model(self, target_model: nn.Module) -> None: loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - self.model.lm_head = target_model.lm_head + self.model.model.embed_tokens = target_model.model.embed_tokens + # self.model.lm_head = target_model.lm_head # FIXME(woosuk): The logic here is duplicated with the main sampling code. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 70e8bd75ec9..a1b74a83139 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1061,7 +1061,7 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states, extra_hidden_states = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -1192,7 +1192,9 @@ def execute_model( # not include padding. 
target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] + target_hidden_states = [ + h[:num_scheduled_tokens] for h in extra_hidden_states + ] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1213,9 +1215,12 @@ def execute_model( ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] + target_hidden_states = [ + h[token_indices] for h in extra_hidden_states + ] target_slot_mapping = attn_metadata.slot_mapping[token_indices] + target_hidden_states = torch.cat(target_hidden_states, dim=-1) draft_token_ids, draft_probs = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1286,6 +1291,7 @@ def load_model(self) -> None: if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) + pass time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -1438,7 +1444,7 @@ def _dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - hidden_states = model( + hidden_states, _ = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, From 8067e542b0ca576ffd945e777f82e781e72b6a32 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 17 Apr 2025 17:11:21 +0000 Subject: [PATCH 06/15] fixes for EAGLE-3 --- examples/offline_inference/eagle.py | 3 +- vllm/model_executor/models/llama.py | 2 +- vllm/model_executor/models/llama_eagle.py | 88 +++++++++++++---------- vllm/v1/spec_decode/eagle.py | 10 +-- 4 files changed, 57 insertions(+), 46 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 1ed0fd6132f..fe3ed435d17 100644 --- 
a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -45,6 +45,7 @@ def parse_args(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) + parser.add_argument("--use_v1", type=str, default="1") return parser.parse_args() @@ -52,7 +53,7 @@ def main(): args = parse_args() - os.environ["VLLM_USE_V1"] = "1" + os.environ["VLLM_USE_V1"] = args.use_v1 model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4d848a97568..80454ce30e1 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -367,10 +367,10 @@ def forward( extra_hidden_states = [] for idx, layer in enumerate( self.layers[self.start_layer:self.end_layer]): - hidden_states, residual = layer(positions, hidden_states, residual) if idx == 2 or idx == len(self.layers) // 2 or idx == len( self.layers) - 3: extra_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 5db7eace356..e0313e476e8 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -9,7 +9,9 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, 
ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -23,20 +25,52 @@ class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( self, config: LlamaConfig, - disable_input_layernorm: bool, + quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__(config, prefix=prefix) + super().__init__(config, quant_config=quant_config, prefix=prefix) + + # override qkv + self.self_attn.qkv_proj = QKVParallelLinear( + 2 * self.hidden_size, + self.self_attn.head_dim, + self.self_attn.total_num_heads, + self.self_attn.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "qkv_proj"), + ) + + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - # Skip the input_layernorm - # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 - if disable_input_layernorm: - del self.input_layernorm - self.input_layernorm = nn.Identity() + # Fully Connected + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual class LlamaModel(nn.Module): @@ -60,7 +94,6 @@ def __init__( self.layers = nn.ModuleList([ LlamaDecoderLayer( self.config, - disable_input_layernorm=True, prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), ) ]) @@ -72,25 +105,6 @@ def __init__( eps=self.config.rms_norm_eps, 
) - # d2t=torch.zeros((self.config.draft_vocab_size),dtype=torch.long) - # t2d=torch.zeros((self.config.vocab_size),dtype=torch.bool) - # self.register_buffer("d2t", d2t) - # self.register_buffer("t2d", t2d) - - # self.t2d = nn.Parameter( - # torch.zeros((self.config.vocab_size), dtype=torch.bool), - # requires_grad=False, - # ) - - self.input_layernorm = RMSNorm( - self.config.hidden_size, - eps=self.config.rms_norm_eps, - ) - self.hidden_norm = RMSNorm( - self.config.hidden_size, - eps=self.config.rms_norm_eps, - ) - def forward( self, input_ids: torch.Tensor, @@ -98,19 +112,19 @@ def forward( hidden_states: torch.Tensor, ) -> torch.Tensor: input_embeds = self.embed_tokens(input_ids) - input_embeds = self.input_layernorm(input_embeds) - hidden_states = self.hidden_norm(hidden_states) - if (hidden_states.shape != input_embeds.shape): + if (hidden_states.shape[-1] != input_embeds.shape[-1]): hidden_states = self.fc(hidden_states) - hidden_states = torch.cat((input_embeds, hidden_states), dim=-1) + residual = None hidden_states, residual = self.layers[0]( positions, + input_embeds, hidden_states, - None, + residual, ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + + hidden_states, hidden_prenorm = self.norm(hidden_states, residual) + return hidden_states, hidden_prenorm def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -125,10 +139,6 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if 'midlayer.input_layernorm' in name: - name = name.replace('midlayer.', '') - if 'midlayer.hidden_norm' in name: - name = name.replace('midlayer.', '') if 'midlayer.' 
in name: name = name.replace('midlayer.', 'layers.0.') for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e0aeeeba117..11f27d5b43b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -84,12 +84,12 @@ def propose( ) with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states_fwd = self.model( input_ids=input_ids, hidden_states=target_hidden_states, positions=target_positions, ) - sample_hidden_states = hidden_states[last_token_indices] + sample_hidden_states = hidden_states_logits[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids, draft_probs = compute_probs_and_sample_next_token( logits, sampling_metadata) @@ -104,7 +104,7 @@ def propose( draft_probs_list = [draft_probs] positions = target_positions[last_token_indices] - hidden_states = sample_hidden_states + hidden_states = hidden_states_fwd[last_token_indices] attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] @@ -124,12 +124,12 @@ def propose( # Run the model. 
with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( + hidden_states_logits, hidden_states = self.model( input_ids=input_ids, hidden_states=hidden_states, positions=positions, ) - logits = self.model.compute_logits(hidden_states, None) + logits = self.model.compute_logits(hidden_states_logits, None) draft_token_ids, probs = compute_probs_and_sample_next_token( logits, sampling_metadata) draft_token_ids_list.append(draft_token_ids) From c0433db34fdaf5c41761d12d8fe6938672d2b9d6 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 10:15:03 -0400 Subject: [PATCH 07/15] remove some comments Signed-off-by: Benjamin Chislett --- vllm/model_executor/models/llama_eagle.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index e0313e476e8..d48d6ab9642 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -25,6 +26,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer): + def __init__( self, config: LlamaConfig, @@ -65,7 +67,8 @@ def forward( hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) # Fully Connected hidden_states = 
self.mlp(hidden_states) @@ -197,8 +200,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) - # pad logits from draft vocab to target vocab - # and convert indices accordingly base = torch.arange(self.config.draft_vocab_size, device=logits.device) targets = base + self.draft_id_to_target_id logits_new = logits.new_full(( @@ -212,8 +213,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, skip_prefixes=None, - # skip_prefixes=(["lm_head."] - # if self.config.tie_word_embeddings else None), ) model_weights = {} From ebc9bde9b08d75f1310bd0a0f97eac012920b595 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 10:26:30 -0400 Subject: [PATCH 08/15] move eagle3 llama code Signed-off-by: Benjamin Chislett --- vllm/model_executor/models/{llama_eagle.py => llama_eagle3.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm/model_executor/models/{llama_eagle.py => llama_eagle3.py} (100%) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle3.py similarity index 100% rename from vllm/model_executor/models/llama_eagle.py rename to vllm/model_executor/models/llama_eagle3.py From aa11bef1632c345224dda0aa2a56248c5de8ea55 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 12:13:18 -0400 Subject: [PATCH 09/15] eagle3 refactor Signed-off-by: Benjamin Chislett --- examples/offline_inference/eagle.py | 8 +- tests/models/registry.py | 4 +- vllm/config.py | 8 +- vllm/engine/arg_utils.py | 3 +- vllm/model_executor/models/llama.py | 21 ++- vllm/model_executor/models/llama_eagle.py | 151 +++++++++++++++++++++ vllm/model_executor/models/llama_eagle3.py | 2 +- vllm/model_executor/models/registry.py | 1 + vllm/v1/core/sched/scheduler.py | 3 +- vllm/v1/spec_decode/eagle.py | 18 ++- vllm/v1/worker/gpu_model_runner.py | 48 +++++-- 11 files changed, 233 insertions(+), 34 
deletions(-) create mode 100644 vllm/model_executor/models/llama_eagle.py diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index fe3ed435d17..f50863a7abf 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -83,8 +83,9 @@ def main(): max_model_len=max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, + compilation_config=0, speculative_config={ - "method": "eagle", + "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, @@ -120,9 +121,8 @@ def main(): # print acceptance at each token position for i in range(len(acceptance_counts)): - print( - f"acceptance at token {i}: {acceptance_counts[i] / (acceptance_counts[0]):.2f}" - ) + print(f"acceptance at token {i}:" + f"{acceptance_counts[i] / (acceptance_counts[0]):.2f}") if __name__ == "__main__": diff --git a/tests/models/registry.py b/tests/models/registry.py index 6e09c0a8779..45ba5f9202d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -386,10 +386,10 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501, - "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", - tokenizer="meta-llama/Llama-3.1-8B-Instruct"), # noqa: E501 + tokenizer="meta-llama/Llama-3.1-8B-Instruct"), } _TRANSFORMERS_MODELS = { diff --git a/vllm/config.py b/vllm/config.py index 7e2869e4eab..e9f74106300 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2131,6 +2131,7 @@ class SpeculativeConfig: Minimum size of ngram token window when using Ngram proposer, if provided. Defaults to 1. 
- eagle + - eagle3 - medusa - mlp_speculator - draft_model @@ -2370,9 +2371,10 @@ def __post_init__(self): ) # Automatically detect the method - if self.method == 'eagle': + if self.method == 'eagle' or self.method == 'eagle3': pass - elif "eagle-" in self.draft_model_config.model.lower(): + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): self.method = "eagle" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" @@ -2383,7 +2385,7 @@ def __post_init__(self): self.method = "draft_model" # Replace hf_config for EAGLE draft_model - if self.method == "eagle": + if self.method == "eagle" or self.method == "eagle3": if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 85b3ddfce48..262c76c8e69 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1515,7 +1515,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if speculative_method: if speculative_method in ("ngram", "[ngram]"): is_ngram_enabled = True - elif speculative_method == "eagle": + elif speculative_method == "eagle" or \ + speculative_method == "eagle3": is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 80454ce30e1..39e3162cd9e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -339,6 +339,8 @@ def __init__(self, else: self.norm = PPMissingLayer() + self.aux_hidden_state_layers = [] + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -364,12 +366,11 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - 
extra_hidden_states = [] + aux_hidden_states = [] for idx, layer in enumerate( self.layers[self.start_layer:self.end_layer]): - if idx == 2 or idx == len(self.layers) // 2 or idx == len( - self.layers) - 3: - extra_hidden_states.append(hidden_states + residual) + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: @@ -379,7 +380,10 @@ def forward( }) hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states, extra_hidden_states + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: @@ -533,6 +537,13 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: list[int]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> list[int]: + num_layers = len(self.model.layers) + return [2, num_layers // 2, num_layers - 3] + def _init_model(self, vllm_config: VllmConfig, prefix: str = "", diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py new file mode 100644 index 00000000000..3a265343657 --- /dev/null +++ b/vllm/model_executor/models/llama_eagle.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Set, Tuple + +import torch +import torch.nn as nn +from transformers import LlamaConfig + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import 
(LlamaDecoderLayer, + LlamaForCausalLM) + +from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +class LlamaDecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: LlamaConfig, + disable_input_layernorm: bool, + prefix: str = "", + ) -> None: + super().__init__(config, prefix=prefix) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if disable_input_layernorm: + del self.input_layernorm + self.input_layernorm = nn.Identity() + + +class LlamaModel(nn.Module): + + def __init__( + self, + *, + model_config: ModelConfig, + start_layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + self.config = model_config.hf_config + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer( + self.config, + i == 0, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.fc( + torch.cat((input_embeds, hidden_states), dim=-1)) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states + residual, hidden_states + residual + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + 
(".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleLlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): + nn.Module.__init__(self) + self.config = model_config.hf_config + self.model = LlamaModel(model_config=model_config, + start_layer_id=start_layer_id, + prefix="model") + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.model(input_ids, positions, hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + + model_weights = {} + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + model_weights[name] = loaded_weight + + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index d48d6ab9642..2f35512ed9b 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -161,7 +161,7 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class EagleLlamaForCausalLM(LlamaForCausalLM): +class Eagle3LlamaForCausalLM(LlamaForCausalLM): def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): nn.Module.__init__(self) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 670a4439284..671179bc521 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -211,6 +211,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ccddd341743..4305b987286 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -115,7 +115,8 @@ def __init__( cache_size=encoder_cache_size) self.num_lookahead_tokens = 0 - if speculative_config and speculative_config.method == "eagle": + if speculative_config and (speculative_config.method == "eagle" + or speculative_config.method == "eagle3"): self.num_lookahead_tokens = \ speculative_config.num_speculative_tokens diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 11f27d5b43b..26624e6e06d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,6 +9,7 @@ from 
vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM +from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -198,16 +199,23 @@ def load_model(self, target_model: nn.Module) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = EagleLlamaForCausalLM( - model_config=draft_model_config, - start_layer_id=target_layer_num).to(target_device) + if self.vllm_config.speculative_config.method == "eagle": + self.model = EagleLlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) + else: + self.model = Eagle3LlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - self.model.model.embed_tokens = target_model.model.embed_tokens - # self.model.lm_head = target_model.lm_head + if self.vllm_config.speculative_config.method == "eagle3": + self.model.model.embed_tokens = target_model.model.embed_tokens + else: + self.model.lm_head = target_model.lm_head # FIXME(woosuk): The logic here is duplicated with the main sampling code. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 913a721cc28..052c93db791 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -159,14 +159,18 @@ def __init__( # Set up speculative decoding. 
self.use_spec_decode = False + self.use_aux_hidden_state_outputs = False if self.speculative_config: self.use_spec_decode = True if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.method == "eagle" or \ + self.speculative_config.method == "eagle3": self.drafter = EagleProposer(self.vllm_config, self.device) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -1060,12 +1064,18 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): - hidden_states, extra_hidden_states = self.model( + output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = output + else: + hidden_states = output + if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. return hidden_states @@ -1163,7 +1173,8 @@ def execute_model( assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) - elif self.speculative_config.method == "eagle": + elif self.speculative_config.method == "eagle" or \ + self.speculative_config.method == "eagle3": assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] @@ -1191,9 +1202,12 @@ def execute_model( # not include padding. 
target_token_ids = self.input_ids[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] - target_hidden_states = [ - h[:num_scheduled_tokens] for h in extra_hidden_states - ] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[:num_scheduled_tokens] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1214,12 +1228,16 @@ def execute_model( ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] - target_hidden_states = [ - h[token_indices] for h in extra_hidden_states - ] + if self.use_aux_hidden_state_outputs: + target_hidden_states = [ + h[token_indices] for h in aux_hidden_states + ] + else: + target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - target_hidden_states = torch.cat(target_hidden_states, dim=-1) + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat(target_hidden_states, dim=-1) draft_token_ids, draft_probs = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1290,7 +1308,9 @@ def load_model(self) -> None: if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) - pass + if self.use_aux_hidden_state_outputs: + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GiB and %.6f seconds", @@ -1443,12 +1463,16 @@ def _dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - hidden_states, _ = model( + outputs = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: 
+ hidden_states, _ = outputs + else: + hidden_states = outputs logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] From 0721cfa50a80374f687cdbba486dd7666ea8f1b0 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 12:27:41 -0400 Subject: [PATCH 10/15] Revert "Merge remote-tracking branch 'luyuzhe111/spec_dec_stats' into eagle-temp" This reverts commit 906e2b3cd9e3143d48343813790cec9f18a7135a, reversing changes made to 1dd23386ecab7b7c50ea61b8ff37ca14d2dbc0f7. Signed-off-by: Benjamin Chislett --- examples/offline_inference/eagle.py | 17 +++---------- vllm/outputs.py | 8 ------ vllm/v1/core/sched/scheduler.py | 7 ++---- vllm/v1/engine/__init__.py | 1 - vllm/v1/engine/llm_engine.py | 23 ++++++----------- vllm/v1/engine/output_processor.py | 39 ++++++----------------------- vllm/v1/engine/processor.py | 3 +-- vllm/v1/spec_decode/metrics.py | 11 +++----- 8 files changed, 24 insertions(+), 85 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index f50863a7abf..573dcab3a9a 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -45,7 +45,6 @@ def parse_args(): parser.add_argument("--enable_chunked_prefill", action='store_true') parser.add_argument("--max_num_batched_tokens", type=int, default=2048) parser.add_argument("--temp", type=float, default=0) - parser.add_argument("--use_v1", type=str, default="1") return parser.parse_args() @@ -53,8 +52,6 @@ def main(): args = parse_args() - os.environ["VLLM_USE_V1"] = args.use_v1 - model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" @@ -103,16 +100,10 @@ def main(): # to account for the token from the target model that's always going to be # accepted acceptance_counts = [0] * (args.num_spec_tokens + 1) - if args.use_v1 == '1': - for output in outputs: - for step, count in enumerate( - output.spec_token_acceptance_counts[0]): - 
acceptance_counts[step] += count - else: - for output in outputs: - for step, count in enumerate( - output.metrics.spec_token_acceptance_counts): - acceptance_counts[step] += count + for output in outputs: + for step, count in enumerate( + output.metrics.spec_token_acceptance_counts): + acceptance_counts[step] += count print("-" * 50) print(f"mean acceptance length: \ diff --git a/vllm/outputs.py b/vllm/outputs.py index 19d8fe08eb6..014e8d5d882 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -33,10 +33,6 @@ class CompletionOutput: to stop, None if the completion finished for some other reason including encountering the EOS token. lora_request: The LoRA request that was used to generate the output. - spec_token_acceptance_counts: A list tracking the total number of - accepted tokens at each speculation step for a request. Its length - is num_spec_tokens + 1 since there is always one token generated - by the target model. """ index: int @@ -47,7 +43,6 @@ class CompletionOutput: finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None lora_request: Optional[LoRARequest] = None - spec_token_acceptance_counts: Optional[list[int]] = None def finished(self) -> bool: return self.finish_reason is not None @@ -138,9 +133,6 @@ def __init__( self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids self.num_cached_tokens = num_cached_tokens - self.spec_token_acceptance_counts = [ - o.spec_token_acceptance_counts for o in outputs - ] def add(self, next_output: "RequestOutput") -> None: """Merge subsequent RequestOutput into this one""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4305b987286..716fc940023 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -609,8 +609,7 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=len(scheduled_spec_token_ids), - 
num_accepted_tokens=len(generated_token_ids) - 1, - request_id=req_id) + num_accepted_tokens=len(generated_token_ids) - 1) cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -771,13 +770,11 @@ def make_spec_decoding_stats( spec_decoding_stats: Optional[SpecDecodingStats], num_draft_tokens: int, num_accepted_tokens: int, - request_id: str, ) -> Optional[SpecDecodingStats]: if not self.log_stats: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats() spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens, - request_id=request_id) + num_accepted_tokens=num_accepted_tokens) return spec_decoding_stats diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 98044cd2501..af4122a5107 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -60,7 +60,6 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] - num_spec_tokens: int class EngineCoreEventType(enum.IntEnum): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index be372b5eaf9..c05319f3d80 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -92,7 +92,7 @@ def __init__( asyncio_mode=False, vllm_config=vllm_config, executor_class=executor_class, - log_stats=True, # FIXME: implement + log_stats=False, # FIXME: implement ) if not multiprocess_mode: @@ -183,20 +183,11 @@ def add_request( priority: int = 0, ) -> None: # Process raw inputs into the request. 
- num_spec_tokens = 0 - if self.vllm_config.speculative_config is not None: - num_spec_tokens = ( - self.vllm_config.speculative_config.num_speculative_tokens) - request = self.processor.process_inputs( - request_id, - prompt, - params, - arrival_time, - lora_request, - trace_headers, - prompt_adapter_request, - priority, - num_spec_tokens=num_spec_tokens) + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) n = params.n if isinstance(params, SamplingParams) else 1 @@ -232,7 +223,7 @@ def step(self) -> list[RequestOutput]: # 2) Process EngineCoreOutputs. processed_outputs = self.output_processor.process_outputs( - outputs.outputs, scheduler_stats=outputs.scheduler_stats) + outputs.outputs) # 3) Abort any reqs that finished due to stop strings. self.engine_core.abort_requests(processed_outputs.reqs_to_abort) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 7f9f0b6ade7..21e2a1aee4e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -14,7 +14,7 @@ from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, - RequestStateStats, SchedulerStats) + RequestStateStats) class RequestOutputCollector: @@ -89,7 +89,6 @@ def __init__( arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, - num_spec_tokens: int = 0, ): self.request_id = request_id self.parent_req = parent_req @@ -108,8 +107,6 @@ def __init__( self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None - self.spec_token_acceptance_counts = [0] * (num_spec_tokens + 1) - @classmethod def from_new_request( cls, @@ -144,13 +141,13 @@ def from_new_request( arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, - num_spec_tokens=request.num_spec_tokens, ) 
def make_request_output( - self, new_token_ids: list[int], finish_reason: Optional[FinishReason], + self, + new_token_ids: list[int], + finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], - spec_token_acceptance_counts: Optional[list[int]] ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -161,10 +158,7 @@ def make_request_output( return None completion_output = self._new_completion_output( - new_token_ids, - finish_reason, - stop_reason, - spec_token_acceptance_counts=spec_token_acceptance_counts) + new_token_ids, finish_reason, stop_reason) request_id = self.request_id if self.parent_req is None: @@ -204,7 +198,6 @@ def _new_completion_output( token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], - spec_token_acceptance_counts: Optional[list[int]], ) -> CompletionOutput: finished = finish_reason is not None @@ -227,8 +220,7 @@ def _new_completion_output( logprobs=logprobs, cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, - stop_reason=stop_reason if finished else None, - spec_token_acceptance_counts=spec_token_acceptance_counts) + stop_reason=stop_reason if finished else None) class OutputProcessor: @@ -303,7 +295,6 @@ def process_outputs( engine_core_outputs: list[EngineCoreOutput], engine_core_timestamp: Optional[float] = None, iteration_stats: Optional[IterationStats] = None, - scheduler_stats: Optional[SchedulerStats] = None, ) -> OutputProcessorOutput: """ Process the EngineCoreOutputs: @@ -342,8 +333,6 @@ def process_outputs( self._update_stats_from_output(req_state, engine_core_output, engine_core_timestamp, iteration_stats) - self._update_stats_from_scheduler(req_id, req_state, - scheduler_stats) new_token_ids = engine_core_output.new_token_ids finish_reason = engine_core_output.finish_reason @@ -363,11 +352,7 @@ def process_outputs( # 4) Create and handle RequestOutput objects. 
if request_output := req_state.make_request_output( - new_token_ids, - finish_reason, - stop_reason, - spec_token_acceptance_counts=req_state. - spec_token_acceptance_counts): + new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) @@ -433,13 +418,3 @@ def _update_stats_from_finished(self, req_state: RequestState, ParentRequest.observe_finished_request( req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens) - - def _update_stats_from_scheduler( - self, req_id: str, req_state: RequestState, - scheduler_stats: Optional[SchedulerStats]): - if scheduler_stats is not None and \ - scheduler_stats.spec_decoding_stats is not None: - num_accepted_tokens = scheduler_stats. \ - spec_decoding_stats.per_request_stats.get(req_id, 0) - for i in range(num_accepted_tokens): - req_state.spec_token_acceptance_counts[i] += 1 diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6d4a95f47ed..afbbddb86d5 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -201,7 +201,6 @@ def process_inputs( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - num_spec_tokens: int = 0, ) -> EngineCoreRequest: # TODO(woosuk): Support pooling models. 
@@ -317,7 +316,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, - num_spec_tokens=num_spec_tokens) + ) def _validate_model_inputs(self, inputs: ProcessorInputs, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 5884af4aa2b..cc453b74f7e 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass, field +from dataclasses import dataclass import numpy as np @@ -13,25 +13,20 @@ class SpecDecodingStats: num_draft_tokens: int = 0 num_accepted_tokens: int = 0 - per_request_stats: dict = field(default_factory=dict) def take(self): copied = SpecDecodingStats(self.num_draft_tokens, - self.num_accepted_tokens, - self.per_request_stats) + self.num_accepted_tokens) self.reset() return copied def reset(self): self.num_draft_tokens = 0 self.num_accepted_tokens = 0 - self.per_request_stats = {} - def observe(self, num_draft_tokens: int, num_accepted_tokens: int, - request_id: str): + def observe(self, num_draft_tokens: int, num_accepted_tokens: int): self.num_draft_tokens += num_draft_tokens self.num_accepted_tokens += num_accepted_tokens - self.per_request_stats[request_id] = num_accepted_tokens + 1 class SpecDecodingMetrics: From 0a55267a5f28b69451b5176ab0af4a29277b474f Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 12:34:19 -0400 Subject: [PATCH 11/15] skip broken metrics on eagle.py inference script Signed-off-by: Benjamin Chislett --- examples/offline_inference/eagle.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 573dcab3a9a..27eb86c2b0d 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -96,6 +96,9 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not hasattr(outputs, 
"metrics") or outputs.metrics is None: + return + # calculate the average number of accepted tokens per forward pass, +1 is # to account for the token from the target model that's always going to be # accepted From 1915bc25e968da30305387bbc588949a5c3211f6 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 19:38:54 +0000 Subject: [PATCH 12/15] fixes for 70B Signed-off-by: Benjamin Chislett --- vllm/model_executor/models/llama_eagle3.py | 13 +++++++++---- vllm/v1/spec_decode/eagle.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 2f35512ed9b..13bca8cab64 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -100,9 +100,14 @@ def __init__( prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), ) ]) - self.fc = torch.nn.Linear(self.config.hidden_size * 3, - self.config.hidden_size, - bias=False) + if hasattr(self.config, "target_hidden_size"): + self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, + self.config.hidden_size, + bias=False) + else: + self.fc = torch.nn.Linear(self.config.hidden_size * 3, + self.config.hidden_size, + bias=False) self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, @@ -225,4 +230,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = "model." 
+ name model_weights[name] = loaded_weight - loader.load_weights(model_weights.items()) + return loader.load_weights(model_weights.items()) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 26624e6e06d..6b914753d9a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -6,6 +6,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context +from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM @@ -13,6 +14,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata +logger = init_logger(__name__) + class EagleProposer: @@ -208,13 +211,17 @@ def load_model(self, target_model: nn.Module) -> None: model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) - self.model.load_weights( + loaded_weights = self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) if self.vllm_config.speculative_config.method == "eagle3": - self.model.model.embed_tokens = target_model.model.embed_tokens + if "model.embed_tokens.weight" not in loaded_weights: + logger.info( + "Loading EAGLE embedding weights from the target model.") + self.model.model.embed_tokens = target_model.model.embed_tokens else: + logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_model.lm_head From a3c22cfb79342fb17e4050291c7e5494a5634b48 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 21 Apr 2025 16:29:21 -0400 Subject: [PATCH 13/15] fix nitpicks Signed-off-by: Benjamin Chislett --- tests/models/registry.py | 2 +- vllm/config.py | 4 ++-- vllm/engine/arg_utils.py | 3 +-- vllm/model_executor/models/llama.py 
| 4 ---- vllm/v1/spec_decode/eagle.py | 1 + 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index a8827a16df7..2f7ddfd907e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -391,7 +391,7 @@ def check_available_online( "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501, + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", diff --git a/vllm/config.py b/vllm/config.py index 08f0a9cfdbc..634317d114c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2389,7 +2389,7 @@ def __post_init__(self): ) # Automatically detect the method - if self.method == 'eagle' or self.method == 'eagle3': + if self.method in ('eagle', 'eagle3'): pass elif "eagle-" in self.draft_model_config.model.lower() or \ "eagle3-" in self.draft_model_config.model.lower(): @@ -2403,7 +2403,7 @@ def __post_init__(self): self.method = "draft_model" # Replace hf_config for EAGLE draft_model - if self.method == "eagle" or self.method == "eagle3": + if self.method in ("eagle", "eagle3"): if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 984c714f5ea..932580a097f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1462,8 +1462,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if speculative_method: if speculative_method in ("ngram", "[ngram]"): is_ngram_enabled = True - elif speculative_method == "eagle" or \ - speculative_method == "eagle3": + elif speculative_method in ("eagle", 
"eagle3"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 39e3162cd9e..1f8b0f9b886 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -223,10 +223,6 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - # self.input_hidden_size = config.hidden_size - # if hasattr(config, "input_hidden_size") and \ - # config.input_hidden_size is not None: - # self.input_hidden_size = config.input_hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e5681e03d33..29dfafbddd5 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -236,6 +236,7 @@ def load_model(self, target_model: nn.Module) -> None: model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) else: + assert self.vllm_config.speculative_config.method == "eagle3" self.model = Eagle3LlamaForCausalLM( model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) From 7ca6242e0d46acfe4633bb37e305f59a6f8abd32 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 24 Apr 2025 09:55:33 -0400 Subject: [PATCH 14/15] cleanup for pr Signed-off-by: Benjamin Chislett --- examples/offline_inference/eagle.py | 1 - vllm/config.py | 6 ++++++ vllm/model_executor/models/llama.py | 14 +++++--------- vllm/model_executor/models/llama_eagle.py | 3 ++- vllm/model_executor/models/llama_eagle3.py | 1 - 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 27eb86c2b0d..474b745a610 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -80,7 +80,6 @@ def main(): 
max_model_len=max_model_len, max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, - compilation_config=0, speculative_config={ "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", "model": eagle_dir, diff --git a/vllm/config.py b/vllm/config.py index 634317d114c..f9dc6b300be 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2601,6 +2601,12 @@ def _verify_args(self) -> None: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: + raise ValueError( + "Eagle3 is only supported for Llama models. " + f"Got {self.target_model_config.hf_text_config.model_type=}") + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 1f8b0f9b886..3c157c20923 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -113,10 +113,6 @@ def __init__(self, super().__init__() layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size - self.input_hidden_size = hidden_size - if hasattr(config, "input_hidden_size") and \ - config.input_hidden_size is not None: - self.input_hidden_size = config.input_hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -144,7 +140,7 @@ def __init__(self, self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size=self.input_hidden_size, + hidden_size=hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, @@ -335,7 +331,7 @@ def __init__(self, else: self.norm = PPMissingLayer() - self.aux_hidden_state_layers = [] + self.aux_hidden_state_layers: tuple[int] = tuple() 
self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( @@ -533,12 +529,12 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def set_aux_hidden_state_layers(self, layers: list[int]) -> None: + def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: self.model.aux_hidden_state_layers = layers - def get_eagle3_aux_hidden_state_layers(self) -> list[int]: + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: num_layers = len(self.model.layers) - return [2, num_layers // 2, num_layers - 3] + return (2, num_layers // 2, num_layers - 3) def _init_model(self, vllm_config: VllmConfig, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 3a265343657..06f7cb08a7c 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -82,7 +82,8 @@ def forward( hidden_states, residual, ) - return hidden_states + residual, hidden_states + residual + hidden_states = hidden_states + residual + return hidden_states, hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 13bca8cab64..ffbb9d75a06 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -93,7 +93,6 @@ def __init__( self.config.hidden_size, prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.config.input_hidden_size = 2 * self.config.hidden_size self.layers = nn.ModuleList([ LlamaDecoderLayer( self.config, From e6de768670bcb261b2fb1b262221d5ab3b1db1b4 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 24 Apr 2025 09:55:53 -0400 Subject: [PATCH 15/15] tests for eagle3 Signed-off-by: Benjamin Chislett --- tests/v1/e2e/test_spec_decode.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git 
a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 67371498059..eb638e2efac 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -50,12 +50,15 @@ def sampling_config(): @pytest.fixture def model_name(): - return "meta-llama/Meta-Llama-3-8B-Instruct" + return "meta-llama/Llama-3.1-8B-Instruct" -@pytest.fixture def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3-Instruct-8B" + return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + +def eagle3_model_name(): + return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" def test_ngram_correctness( @@ -102,12 +105,13 @@ def test_ngram_correctness( del spec_llm +@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, - eagle_model_name: str, + use_eagle3: bool, ): ''' Compare the outputs of a original LLM and a speculative LLM @@ -120,11 +124,13 @@ def test_eagle_correctness( ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + spec_model_name = eagle3_model_name( + ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, speculative_config={ - "method": "eagle", - "model": eagle_model_name, + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, "num_speculative_tokens": 3, }, max_model_len=1024,