diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 57d258229..37f87d7b6 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -66,6 +66,9 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
+from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
+    set_eagle3_aux_hidden_state_layers,
+)
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
 from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
 from vllm.v1.worker.gpu.states import RequestState
@@ -133,14 +136,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.output_copy_stream = torch.cuda.Stream(self.device)
         self.output_copy_event = torch.cuda.Event()
 
+        # Pipeline parallelism.
+        self.pp_size = self.parallel_config.pipeline_parallel_size
+        self.use_pp = self.pp_size > 1
+        if self.use_pp:
+            self.is_first_pp_rank = get_pp_group().is_first_rank
+            self.is_last_pp_rank = get_pp_group().is_last_rank
+        else:
+            self.is_first_pp_rank = True
+            self.is_last_pp_rank = True
+
+        # Decode context parallelism.
+        self.dcp_size = self.parallel_config.decode_context_parallel_size
+        self.use_dcp = self.dcp_size > 1
+        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
+        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
+
+        self.speculator = None
+        self.use_aux_hidden_state_outputs = False
         if self.speculative_config is not None:
             self.do_spec_decode = True
             self.num_speculative_steps = self.speculative_config.num_speculative_tokens
-            self.speculator = init_speculator(self.vllm_config, self.device)
+            if self.is_last_pp_rank:
+                self.speculator = init_speculator(self.vllm_config, self.device)
+
+            if self.speculative_config.method == "eagle3":
+                # EAGLE3 may require auxiliary hidden states from target model outputs.
+                self.use_aux_hidden_state_outputs = True
+                if self.pp_size > 1:
+                    raise ValueError("EAGLE3 with pipeline parallel is not supported.")
         else:
             self.do_spec_decode = False
             self.num_speculative_steps = 0
-            self.speculator = None
+
+        # Draft tokens propagation - for spec-dec + struct outputs.
+        self.draft_tokens_handler = DraftTokensHandler(self.device)
+
         self.req_states = RequestState(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
@@ -176,28 +207,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         # LoRA-related workers.
         self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
-
-        # Draft tokens propagation - for spec-dec + struct outputs.
-        self.draft_tokens_handler = DraftTokensHandler(self.device)
-
         # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
 
-        # Pipeline parallelism.
-        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
-        if self.use_pp:
-            self.is_first_pp_rank = get_pp_group().is_first_rank
-            self.is_last_pp_rank = get_pp_group().is_last_rank
-        else:
-            self.is_first_pp_rank = True
-            self.is_last_pp_rank = True
-
-        # Decode context parallelism.
-        self.dcp_size = self.parallel_config.decode_context_parallel_size
-        self.use_dcp = self.dcp_size > 1
-        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
-        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
-
     def update_max_model_len(self, max_model_len: int) -> None:
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
@@ -220,7 +232,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.model = self.load_lora_model(
                 self.model, self.vllm_config, self.device
             )
-        if self.do_spec_decode:
+
+        if self.use_aux_hidden_state_outputs:
+            assert self.speculative_config is not None
+            set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
+        if self.speculator is not None:
             self.speculator.load_model(self.model)
 
         time_after_load = time.perf_counter()
@@ -271,7 +287,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.kv_cache_config, self.vllm_config, self.device
         )
         check_attention_cp_compatibility(self.vllm_config)
-        if self.do_spec_decode:
+        if self.speculator is not None:
             # HACK(woosuk)
             self.speculator.set_attn(
                 self.kv_cache_config,
@@ -359,7 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return None, None
 
         assert self.execute_model_state is not None
-        hidden_states, input_batch, _ = self.execute_model_state
+        hidden_states, _, input_batch, _ = self.execute_model_state
         assert hidden_states is not None  # Last PP rank always has hidden_states
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         return hidden_states, sample_hidden_states
@@ -399,7 +415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert sample_hidden_states is not None
             self._dummy_sampler_run(sample_hidden_states)
 
-        if self.do_spec_decode:
+        if self.speculator is not None:
            num_tokens_across_dp = make_num_tokens_across_dp(
                 self.parallel_config.data_parallel_size, self.max_num_tokens
             )
@@ -465,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             kv_cache_config=self.kv_cache_config,
             has_lora=self.lora_config is not None,
         )
-        if self.do_spec_decode:
+        if self.speculator is not None:
             self.speculator.capture_model()
 
         end_time = time.perf_counter()
@@ -964,9 +980,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # NOTE(woosuk): Here, we don't need to pass the input tensors,
             # because they are already copied to the CUDA graph input buffers.
             self.kv_connector.pre_forward(scheduler_output)
-            hidden_states = self.cudagraph_manager.run_fullgraph(
+            model_output = self.cudagraph_manager.run_fullgraph(
                 input_batch.num_tokens_after_padding
             )
+            if self.use_aux_hidden_state_outputs:
+                hidden_states, aux_hidden_states = model_output
+            else:
+                hidden_states = model_output
+                aux_hidden_states = None
         else:
             # For piecewise and eager mode, just call model().
             positions = input_batch.positions
@@ -998,12 +1019,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 slot_mapping=input_batch.slot_mappings,
             ):
                 self.kv_connector.pre_forward(scheduler_output)
-                hidden_states = self.model(
+                model_output = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     inputs_embeds=inputs_embeds,
                     intermediate_tensors=intermediate_tensors,
                 )
+                if self.use_aux_hidden_state_outputs:
+                    hidden_states, aux_hidden_states = model_output
+                else:
+                    hidden_states = model_output
+                    aux_hidden_states = None
 
         kv_connector_output = self.kv_connector.post_forward(scheduler_output)
 
@@ -1011,12 +1037,17 @@
             # Non-last PP rank: return IntermediateTensors for sending.
             assert isinstance(hidden_states, IntermediateTensors)
             hidden_states.kv_connector_output = kv_connector_output
-            self.execute_model_state = (None, input_batch, kv_connector_output)
+            self.execute_model_state = (None, None, input_batch, kv_connector_output)
             return hidden_states
 
-        assert isinstance(hidden_states, torch.Tensor)
         # Last rank (or no PP): hidden_states is a tensor for sampling.
-        self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
+        assert isinstance(hidden_states, torch.Tensor)
+        self.execute_model_state = (
+            hidden_states,
+            aux_hidden_states,
+            input_batch,
+            kv_connector_output,
+        )
         return None
 
     @torch.inference_mode()
@@ -1024,7 +1055,9 @@
         self, grammar_output: GrammarOutput | None
     ) -> AsyncOutput | ModelRunnerOutput | None:
         assert self.execute_model_state is not None
-        hidden_states, input_batch, kv_connector_output = self.execute_model_state
+        hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
+            self.execute_model_state
+        )
         self.execute_model_state = None  # type: ignore
 
         if not self.is_last_pp_rank:
@@ -1084,11 +1117,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.postprocess(
             input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
         )
-        if self.do_spec_decode:
+        if self.speculator is not None:
             draft_tokens = self.propose_draft(
                 input_batch,
                 hidden_states,
-                None,  # aux_hidden_states
+                aux_hidden_states,
                 num_sampled,
                 num_rejected,
             )
diff --git a/vllm/v1/worker/gpu/spec_decode/__init__.py b/vllm/v1/worker/gpu/spec_decode/__init__.py
index 07026a512..536b7526b 100644
--- a/vllm/v1/worker/gpu/spec_decode/__init__.py
+++ b/vllm/v1/worker/gpu/spec_decode/__init__.py
@@ -9,7 +9,7 @@ def init_speculator(vllm_config: VllmConfig, device: torch.device):
     speculative_config = vllm_config.speculative_config
     assert speculative_config is not None
     if speculative_config.use_eagle():
-        from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
+        from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
 
         return EagleSpeculator(vllm_config, device)
     raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/__init__.py b/vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
similarity index 100%
rename from vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
rename to vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py b/vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
new file mode 100644
index 000000000..d76d69355
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import cast
+
+import torch.nn as nn
+
+from vllm.config import SpeculativeConfig
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
+
+logger = init_logger(__name__)
+
+
+def set_eagle3_aux_hidden_state_layers(
+    model: nn.Module,
+    spec_config: SpeculativeConfig,
+) -> None:
+    if not supports_eagle3(model):
+        raise RuntimeError("Model does not support EAGLE3 interface")
+    # mypy may infer the class-level overload for supports_eagle3.
+    # Narrow explicitly to the runtime protocol instance.
+    if isinstance(model, type):
+        raise RuntimeError("Expected model instance for EAGLE3 configuration")
+    eagle3_model = cast(SupportsEagle3, model)
+
+    aux_layers = get_eagle3_aux_layers_from_config(spec_config)
+    if aux_layers:
+        logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
+    else:
+        aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers()
+        logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
+    eagle3_model.set_aux_hidden_state_layers(aux_layers)
+
+
+def get_eagle3_aux_layers_from_config(
+    spec_config: SpeculativeConfig,
+) -> tuple[int, ...] | None:
+    if not (spec_config and spec_config.draft_model_config):
+        return None
+    hf_config = spec_config.draft_model_config.hf_config
+    if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
+        return None
+    layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
+    if layer_ids and isinstance(layer_ids, (list, tuple)):
+        return tuple(layer_ids)
+    return None
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
similarity index 97%
rename from vllm/v1/worker/gpu/spec_decode/eagle.py
rename to vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index abbde270f..3cd8afee7 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -20,7 +19,8 @@ from vllm.v1.worker.gpu.attn_utils import (
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
-from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
+from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
+from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
 
 logger = init_logger(__name__)
 
@@ -73,18 +73,7 @@ class EagleSpeculator:
         self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
 
     def load_model(self, target_model: nn.Module) -> None:
-        from vllm.compilation.backends import set_model_tag
-
-        with set_model_tag("eagle_head"):
-            self.model = get_model(
-                vllm_config=self.vllm_config, model_config=self.draft_model_config
-            )
-
-        share_lm_head = True
-        if share_lm_head and hasattr(target_model, "lm_head"):
-            if hasattr(self.model, "lm_head"):
-                del self.model.lm_head
-            self.model.lm_head = target_model.lm_head
+        self.model = load_eagle_model(target_model, self.vllm_config)
 
     def set_attn(
         self,
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/utils.py b/vllm/v1/worker/gpu/spec_decode/eagle/utils.py
new file mode 100644
index 000000000..ee37eadb2
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/utils.py
@@ -0,0 +1,52 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model
+
+
+def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
+    from vllm.compilation.backends import set_model_tag
+
+    speculative_config = vllm_config.speculative_config
+    assert speculative_config is not None
+    draft_model_config = speculative_config.draft_model_config
+    with set_model_tag("eagle_head"):
+        eagle_model = get_model(
+            vllm_config=vllm_config, model_config=draft_model_config
+        )
+
+    # Share target embeddings when the draft checkpoint does not include
+    # its own vocab embedding table.
+    share_embeddings = True
+    if hasattr(eagle_model, "has_own_embed_tokens"):
+        share_embeddings = not eagle_model.has_own_embed_tokens
+    if share_embeddings:
+        target_language_model = (
+            target_model.get_language_model()
+            if hasattr(target_model, "get_language_model")
+            else target_model
+        )
+        inner_model = getattr(target_language_model, "model", None)
+        target_embed_tokens = None
+        if inner_model is not None:
+            if hasattr(inner_model, "embed_tokens"):
+                target_embed_tokens = inner_model.embed_tokens
+            elif hasattr(inner_model, "embedding"):
+                target_embed_tokens = inner_model.embedding
+        if target_embed_tokens is not None and hasattr(eagle_model, "model"):
+            if hasattr(eagle_model.model, "embed_tokens"):
+                del eagle_model.model.embed_tokens
+            eagle_model.model.embed_tokens = target_embed_tokens
+
+    # Only share target lm_head when the draft model does not own one.
+    share_lm_head = True
+    if hasattr(eagle_model, "has_own_lm_head"):
+        share_lm_head = not eagle_model.has_own_lm_head
+    if share_lm_head and hasattr(target_model, "lm_head"):
+        if hasattr(eagle_model, "lm_head"):
+            del eagle_model.lm_head
+        eagle_model.lm_head = target_model.lm_head
+
+    return eagle_model
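
Illustrative note (not part of the patch): the sketch below shows how the pieces added in this diff are expected to fit together at load time. The `runner` object and the helper name `wire_up_speculator` are hypothetical stand-ins for GPUModelRunner state; only the imported helper and the attribute names come from the hunks above.

# Sketch only, assuming a GPUModelRunner-like `runner` on the last PP rank.
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
    set_eagle3_aux_hidden_state_layers,
)


def wire_up_speculator(runner) -> None:
    if runner.use_aux_hidden_state_outputs:
        # EAGLE3: take auxiliary hidden-state layer ids from the draft model's
        # hf_config when available, otherwise fall back to the target model's
        # defaults, and register them on the target model.
        assert runner.speculative_config is not None
        set_eagle3_aux_hidden_state_layers(runner.model, runner.speculative_config)
    if runner.speculator is not None:
        # load_eagle_model() builds the EAGLE head and shares the target's
        # embed_tokens / lm_head when the draft checkpoint lacks its own.
        runner.speculator.load_model(runner.model)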