diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index b6a1d2db3..801741eca 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -52,7 +52,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -442,7 +442,7 @@ class Qwen2Model(nn.Module):
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -488,6 +488,15 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        """Record which decoder layers should emit auxiliary hidden states (EAGLE3)."""
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Return default layer indices for EAGLE3 auxiliary hidden states."""
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,