diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index 474b745a610..91e2f68ecff 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -36,6 +36,10 @@ def parse_args():
         help="downloaded from the eagle repo " \
             "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
     )
+    parser.add_argument("--method",
+                        type=str,
+                        default='eagle',
+                        choices=['eagle', 'eagle3'])
     parser.add_argument("--max_num_seqs", type=int, default=8)
     parser.add_argument("--num_prompts", type=int, default=80)
     parser.add_argument("--num_spec_tokens", type=int, default=2)
@@ -53,7 +57,13 @@ def main():
     args = parse_args()

     model_dir = "meta-llama/Llama-3.1-8B-Instruct"
-    eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+
+    if args.method == 'eagle':
+        eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+    elif args.method == 'eagle3':
+        eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+    else:
+        raise ValueError(f"unknown method: {args.method}")

     max_model_len = 2048

@@ -81,7 +91,7 @@ def main():
         max_num_seqs=args.max_num_seqs,
         gpu_memory_utilization=0.8,
         speculative_config={
-            "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
+            "method": args.method,
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "draft_tensor_parallel_size": args.draft_tp,
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 7012131d053..a1570b7eccc 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -347,8 +347,12 @@ def configure_post_pass(self):
         PASS_KEY = "post_grad_custom_post_pass"
         if PASS_KEY in inductor_config:
             # Config should automatically wrap all inductor passes
-            assert isinstance(inductor_config[PASS_KEY], InductorPass)
-            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                assert (inductor_config[PASS_KEY].uuid() ==
+                        self.post_grad_pass_manager.uuid())
+            else:
+                assert isinstance(inductor_config[PASS_KEY], InductorPass)
+                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@@ -408,8 +412,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             )
             self.compilation_config.cache_dir = cache_dir

-        cache_dir = self.compilation_config.cache_dir
+        if compilation_counter.num_graphs_seen > 0:
+            cache_dir = self.compilation_config.cache_dir + \
+                f'-{compilation_counter.num_graphs_seen}'
+        else:
+            cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
         local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 56e53ac2b81..76655bd71b1 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -6,7 +6,8 @@
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.config import ModelConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -37,17 +38,19 @@ def __init__(
         self.input_layernorm = nn.Identity()


+@support_torch_compile
 class LlamaModel(nn.Module):

     def __init__(
         self,
         *,
-        model_config: ModelConfig,
-        start_layer_id: int = 0,
+        vllm_config: VllmConfig,
         prefix: str = "",
+        start_layer_id: int = 0,
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
@@ -75,8 +78,7 @@ def forward(
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
         residual = None
-        for i in range(len(self.layers)):
-            layer = self.layers[i]
+        for layer in self.layers:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
@@ -117,12 +119,13 @@ def load_weights(self, weights: Iterable[Tuple[str,

 class EagleLlamaForCausalLM(LlamaForCausalLM):

-    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
+                                prefix="model",
+                                start_layer_id=start_layer_id)

         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 0b18e4a8fe2..c42f19fee17 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -6,7 +6,7 @@
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -167,8 +167,9 @@ def load_weights(self, weights: Iterable[Tuple[str,

 class Eagle3LlamaForCausalLM(LlamaForCausalLM):

-    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
+        model_config = vllm_config.speculative_config.draft_model_config
         self.config = model_config.hf_config
         self.model = LlamaModel(model_config=model_config,
                                 start_layer_id=start_layer_id,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 8c45ca9a319..81508c2e069 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl

-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
@@ -26,10 +26,41 @@ def __init__(
         device: torch.device,
     ):
         self.vllm_config = vllm_config
+        self.method = self.vllm_config.speculative_config.method
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
+
+        self.dtype = vllm_config.model_config.dtype
+
+        self.max_num_tokens = vllm_config.scheduler_config \
+            .max_num_batched_tokens
+
+        self.hidden_size = vllm_config.model_config.get_hidden_size()
+
+        # TODO: make eagle3 compatible with cudagraph
+        self.use_cuda_graph = self.method != 'eagle3' and \
+            (self.vllm_config.compilation_config.level
+             == CompilationLevel.PIECEWISE and
+             not self.vllm_config.model_config.enforce_eager)
+
+        self.cudagraph_batch_sizes = list(
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+
+        # persistent buffers for cuda graph
+        self.input_ids = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int32,
+                                     device=device)
+        self.positions = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int64,
+                                     device=device)
+
+        self.hidden_states = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -59,13 +90,12 @@ def propose(
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1

-        input_ids = torch.empty_like(target_token_ids)
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        input_ids[:-1] = target_token_ids[1:]
+        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        input_ids[last_token_indices] = next_token_ids
+        self.input_ids[last_token_indices] = next_token_ids

         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()
@@ -88,14 +118,30 @@ def propose(
             prefix_kv_lens=None,
             suffix_kv_lens=None,
         )
+        if self.use_cuda_graph and \
+            num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+        else:
+            num_input_tokens = num_tokens
+        # copy inputs to buffer for cudagraph
+        self.positions[:num_tokens] = target_positions

-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states_logits, hidden_states_fwd = self.model(
-                input_ids=input_ids,
-                hidden_states=target_hidden_states,
-                positions=target_positions,
+        if self.method == 'eagle':
+            self.hidden_states[:num_tokens] = target_hidden_states
+            hidden_states = self.hidden_states
+        else:
+            # TODO: make eagle3 compatible with cuda graph
+            hidden_states = target_hidden_states
+
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
+            last_hidden_states, hidden_states = self.model(
+                input_ids=self.input_ids[:num_input_tokens],
+                positions=self.positions[:num_input_tokens],
+                hidden_states=hidden_states[:num_input_tokens],
             )
-        sample_hidden_states = hidden_states_logits[last_token_indices]
+        sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)

@@ -108,13 +154,20 @@ def propose(
         draft_token_ids_list = [draft_token_ids]

         positions = target_positions[last_token_indices]
-        hidden_states = hidden_states_fwd[last_token_indices]
+        hidden_states = hidden_states[last_token_indices]
+        if self.use_cuda_graph and \
+            batch_size <= self.cudagraph_batch_sizes[-1]:
+            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+        else:
+            input_batch_size = batch_size
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
-            input_ids = draft_token_ids_list[-1]
+            # cast to int32 is crucial when eagle model is compiled.
+            # tensor.argmax() returns int64 by default.
+            input_ids = draft_token_ids_list[-1].int()
             positions += 1

             # NOTE(woosuk): We should handle the case where the draft model
@@ -152,14 +205,27 @@ def propose(
             attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                     PADDING_SLOT_ID)

+            # copy inputs to buffer for cudagraph
+            self.input_ids[:batch_size] = input_ids
+            self.positions[:batch_size] = clamped_positions
+
+            if self.method == 'eagle':
+                # TODO: make eagle3 compatible with cudagraph.
+                self.hidden_states[:batch_size] = hidden_states
+                hidden_states = self.hidden_states
+
             # Run the model.
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states_logits, hidden_states = self.model(
-                    input_ids=input_ids,
-                    hidden_states=hidden_states,
-                    positions=clamped_positions,
+            with set_forward_context(attn_metadata,
+                                     self.vllm_config,
+                                     num_tokens=input_batch_size):
+                last_hidden_states, hidden_states = self.model(
+                    input_ids=self.input_ids[:input_batch_size],
+                    positions=self.positions[:input_batch_size],
+                    hidden_states=hidden_states[:input_batch_size],
                 )
-            logits = self.model.compute_logits(hidden_states_logits, None)
+            hidden_states = hidden_states[:batch_size]
+            logits = self.model.compute_logits(last_hidden_states[:batch_size],
+                                               None)
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)

@@ -227,13 +293,11 @@ def load_model(self, target_model: nn.Module) -> None:
         draft_model_cls, arch = ModelRegistry.resolve_model_cls(
             draft_model_config.architectures)
         self.model = draft_model_cls(
-            model_config=draft_model_config,
+            vllm_config=self.vllm_config,
             start_layer_id=target_layer_num).to(target_device)

         loaded_weights = self.model.load_weights(
-            loader.get_all_weights(
-                self.vllm_config.speculative_config.draft_model_config,
-                self.model))
+            loader.get_all_weights(draft_model_config, self.model))
         if self.vllm_config.speculative_config.method == "eagle3":
             if "model.embed_tokens.weight" not in loaded_weights:
                 logger.info(
@@ -243,6 +307,20 @@ def load_model(self, target_model: nn.Module) -> None:
             logger.info("Loading EAGLE LM head weights from the target model.")
             self.model.lm_head = target_model.lm_head

+    @torch.inference_mode()
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            if self.method == 'eagle':
+                self.model(
+                    input_ids=self.input_ids[:num_tokens],
+                    positions=self.positions[:num_tokens],
+                    hidden_states=self.hidden_states[:num_tokens],
+                )
+

 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4711beadbd9..41de305a016 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1106,7 +1106,6 @@ def execute_model(
             # For mid-pipeline stages, return the hidden states.
             return hidden_states

-        hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)

@@ -1172,7 +1171,7 @@ def execute_model(

         # Compute prompt logprobs if needed.
         prompt_logprobs_dict = self._get_prompt_logprobs_dict(
-            hidden_states,
+            hidden_states[:num_scheduled_tokens],
             scheduler_output,
         )

@@ -1222,15 +1221,12 @@ def execute_model(

         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
-            # We need to slice token_ids, positions, and hidden_states
-            # because the eagle head does not use cuda graph and should
-            # not include padding.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = [
-                    h[:num_scheduled_tokens] for h in aux_hidden_states
-                ]
+                target_hidden_states = torch.cat(
+                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                    dim=-1)
             else:
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
             target_slot_mapping = attn_metadata.slot_mapping
@@ -1254,15 +1250,12 @@ def execute_model(
             target_token_ids = self.input_ids[token_indices]
             target_positions = positions[token_indices]
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = [
-                    h[token_indices] for h in aux_hidden_states
-                ]
+                target_hidden_states = torch.cat(
+                    [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = attn_metadata.slot_mapping[token_indices]

-            if self.use_aux_hidden_state_outputs:
-                target_hidden_states = torch.cat(target_hidden_states, dim=-1)
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
@@ -1506,6 +1499,11 @@ def _dummy_run(
             else:
                 hidden_states = outputs

+        if self.use_spec_decode and \
+            self.speculative_config.method in ('eagle', 'eagle3'):
+            assert isinstance(self.drafter, EagleProposer)
+            self.drafter.dummy_run(num_tokens)
+
         logit_indices = np.cumsum(num_scheduled_tokens) - 1
         return hidden_states[logit_indices]
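
The hunk in examples/offline_inference/eagle.py now routes the new --method choice straight into speculative_config instead of sniffing the draft-model name. Below is a minimal offline sketch of that wiring, assuming the vllm.LLM entry point used by the example; the prompt and the two token counts are illustrative, while the model names and config keys come from the hunk above.

from vllm import LLM

method = "eagle3"  # or "eagle", matching the new --method choices
eagle_dir = ("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" if method == "eagle3"
             else "yuhuili/EAGLE-LLaMA3.1-Instruct-8B")

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=2048,
    speculative_config={
        "method": method,                # taken directly from --method
        "model": eagle_dir,
        "num_speculative_tokens": 2,
    },
)
outputs = llm.generate(["The capital of France is"])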
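
The start of EagleProposer.propose() now writes the shifted token ids into the persistent self.input_ids buffer instead of a freshly allocated tensor. The shift itself is unchanged; here is a small standalone torch example of it for three sequences of lengths 1, 2, and 3, with made-up token values (a1=1, b1=11, b2=12, c1..c3=21..23, and sampled next tokens a2=2, b3=13, c4=24).

import torch

# Three sequences flattened: [a1], [b1, b2], [c1, c2, c3].
target_token_ids = torch.tensor([1, 11, 12, 21, 22, 23], dtype=torch.int32)
next_token_ids = torch.tensor([2, 13, 24], dtype=torch.int32)   # a2, b3, c4
cu_num_tokens = torch.tensor([0, 1, 3, 6])
last_token_indices = cu_num_tokens[1:] - 1                      # [0, 2, 5]

num_tokens = target_token_ids.shape[0]
input_ids = torch.zeros(8, dtype=torch.int32)                   # persistent buffer, padded length
# Shift left by one: [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, .]
input_ids[:num_tokens - 1] = target_token_ids[1:]
# Overwrite each sequence's last slot with its sampled next token:
# -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
assert input_ids[:num_tokens].tolist() == [2, 12, 13, 22, 23, 24]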
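
The CUDA-graph path added to EagleProposer relies on two ingredients: rounding the token count up to a captured batch size, and copying per-step inputs into persistent buffers so the replayed graph always reads the same addresses, with results sliced back to the real token count afterwards. The following is a standalone sketch of that pattern only; capture_sizes, pad_to_capture_size, and prepare_inputs are hypothetical stand-ins for vllm_config.pad_for_cudagraph() and compilation_config.cudagraph_capture_sizes, not vLLM API.

import torch

capture_sizes = [1, 2, 4, 8]        # hypothetical ascending capture sizes
max_num_tokens = capture_sizes[-1]

# Persistent buffers: a replayed CUDA graph reads fixed addresses, so inputs
# are copied into pre-allocated tensors rather than freshly allocated ones.
input_ids = torch.zeros(max_num_tokens, dtype=torch.int32)
positions = torch.zeros(max_num_tokens, dtype=torch.int64)

def pad_to_capture_size(num_tokens: int) -> int:
    """Round up to the smallest captured size; beyond the largest, run as-is."""
    for size in capture_sizes:
        if num_tokens <= size:
            return size
    return num_tokens

def prepare_inputs(token_ids: torch.Tensor, pos: torch.Tensor):
    num_tokens = token_ids.shape[0]
    num_input_tokens = pad_to_capture_size(num_tokens)
    input_ids[:num_tokens] = token_ids   # copy into the persistent buffers
    positions[:num_tokens] = pos
    # The model consumes the padded slice; only the first num_tokens rows are
    # real, so downstream outputs must be sliced back to num_tokens.
    return input_ids[:num_input_tokens], positions[:num_input_tokens]

ids, pos = prepare_inputs(torch.tensor([5, 7, 9], dtype=torch.int32),
                          torch.tensor([3, 4, 5]))
assert ids.shape[0] == 4    # 3 real tokens padded up to the capture size 4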