Add Eagle and Eagle3 support to Transformers modeling backend (#30340)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-12-11 17:02:10 +00:00
committed by GitHub
parent aa3c250c48
commit 8781cd6b88
2 changed files with 94 additions and 8 deletions

View File

@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import (
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
SupportsQuant,
@@ -92,7 +94,15 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
class Base(
nn.Module,
VllmModel,
SupportsQuant,
SupportsLoRA,
SupportsPP,
SupportsEagle,
SupportsEagle3,
):
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights`
# Attrs for weight loading (see self.load_weights)
self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes.
"""
"""Ignore unexpected weights whose qualname starts with these prefixes."""
self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes."""
# Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
self._target_class: type[nn.Module] = nn.Module
"""Target class for Eagle3 aux hidden state recording."""
self._layer_names: dict[int, str] = {}
"""Mapping from layer index to layer name for Eagle3."""
self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
"""Kwargs to pass to model forward for Eagle3 aux hidden states."""
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
for child_name, child_module in module.named_children():
new_module = child_module
qual_name = maybe_prefix(prefix, child_name)
# Populate Eagle3 attrs
if (
isinstance(module, nn.ModuleList)
and len(module) == self.text_config.num_hidden_layers
):
self._target_class = type(child_module)
layer_name = qual_name.removeprefix("model.")
self._layer_names[int(child_name)] = layer_name
# Replace modules as needed
if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None)
@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
else:
position_ids = positions[None, ...]
hidden_states = self.model(
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=position_ids,
attention_instances=self.attention_instances,
return_dict=False,
**self._output_aux_hidden_states_kwargs,
**kwargs,
)[0][0, ...] # we remove batch dimension for now
)
# We must remove the batch dimension from these outputs
hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(
@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
f"Transformers modeling backend requires transformers>={required} "
f"for {feature}, but got {installed}"
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Register Eagle3 auxiliary hidden-state recorders for *layers*.

    For each requested layer index this attaches an ``OutputRecorder`` to
    ``self.model`` so that the hidden state fed *into* that layer is
    captured during forward, and remembers the matching ``output_*`` kwarg
    in ``self._output_aux_hidden_states_kwargs`` for later forward calls.
    """
    self.check_version("5.0.0.dev0", "Eagle3 support")
    from transformers.utils.generic import OutputRecorder

    # PreTrainedModel initialises `_can_record_outputs` to None, so make
    # sure there is a dict we can add recorders to.
    if self.model._can_record_outputs is None:
        self.model._can_record_outputs = {}
    for layer_idx in layers:
        # We want the *input* to the layer, hence the shift by one.
        name = self._layer_names[layer_idx - 1]
        key = f"aux_hidden_state_{layer_idx}"
        recorder = OutputRecorder(self._target_class, layer_name=name)
        self.model._can_record_outputs[key] = recorder
        self._output_aux_hidden_states_kwargs[f"output_{key}"] = True
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default layer indices Eagle3 should record.

    Picks one early layer, the middle layer, and one near the end,
    derived from the text config's total hidden-layer count.
    """
    total = self.text_config.num_hidden_layers
    early, middle, late = 2, total // 2, total - 3
    return (early, middle, late)