Add support for Eagle with separate lm-head and embed_tokens layers (#28549)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
Eldar Kurtić
2025-11-15 15:12:02 +01:00
committed by GitHub
parent 085a525332
commit e439c784fa
12 changed files with 205 additions and 64 deletions

View File

@@ -26,7 +26,7 @@ from vllm.model_executor.models.deepseek_v2 import (
)
from vllm.utils import init_logger
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -250,6 +250,7 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(

View File

@@ -85,7 +85,7 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],

View File

@@ -932,13 +932,73 @@ def supports_transcription(
@runtime_checkable
class SupportsEagle3(Protocol):
class SupportsEagleBase(Protocol):
    """Base interface for models that support EAGLE-based speculative decoding.

    Declares two boolean flags, both defaulting to ``False``, that describe
    whether the model has trained its own ``lm_head`` and its own input
    embeddings.
    """

    # A flag that indicates this model has trained its own lm_head.
    has_own_lm_head: bool = False

    # A flag that indicates this model has trained its own input embeddings.
    has_own_embed_tokens: bool = False
@overload
def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...


@overload
def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...


def supports_any_eagle(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
    """Check if model supports any EAGLE variant (1, 2, or 3).

    Args:
        model: A model class or a model instance.

    Returns:
        A truthy result when either ``supports_eagle(model)`` or
        ``supports_eagle3(model)`` accepts ``model``; otherwise the result
        of ``supports_eagle3(model)``.
    """
    # The EAGLE-1/2 check runs first; the EAGLE-3 check is only consulted
    # when the first one rejects the model.
    if supports_eagle(model):
        return True
    return supports_eagle3(model)
@runtime_checkable
class SupportsEagle(SupportsEagleBase, Protocol):
"""The interface required for models that support
EAGLE3 speculative decoding."""
EAGLE-1 and EAGLE-2 speculative decoding."""
supports_eagle: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports EAGLE-1 and EAGLE-2
speculative decoding.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...


@overload
def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...


def supports_eagle(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
    """Check if model satisfies the ``SupportsEagle`` interface.

    Args:
        model: A model class or a model instance.

    Returns:
        The result of ``isinstance(model, SupportsEagle)``.
    """
    return isinstance(model, SupportsEagle)
@runtime_checkable
class SupportsEagle3(SupportsEagleBase, Protocol):
"""The interface required for models that support
EAGLE-3 speculative decoding."""
supports_eagle3: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports EAGLE3
A flag that indicates this model supports EAGLE-3
speculative decoding.
Note:
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""
Set which layers should output auxiliary
hidden states for EAGLE3.
hidden states for EAGLE-3.
Args:
layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ class SupportsEagle3(Protocol):
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""
Get the layer indices that should output auxiliary hidden states
for EAGLE3.
for EAGLE-3.
Returns:
Tuple of layer indices for auxiliary hidden state outputs.

View File

@@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -529,7 +529,9 @@ class LlamaModel(nn.Module):
return loaded_params
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
class LlamaForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],

View File

@@ -35,7 +35,7 @@ from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausa
from vllm.model_executor.models.utils import extract_layer_index
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -212,6 +212,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
if "lm_head" not in name:
name = "model." + name
process_eagle_weight(self, name)
return name, weight
loader = AutoWeightsLoader(

View File

@@ -17,7 +17,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -179,6 +179,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
name, loaded_weight = inputs
if "lm_head" not in name:
name = "model." + name
process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(

View File

@@ -23,7 +23,7 @@ from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
logger = init_logger(__name__)
@@ -324,6 +324,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight
process_eagle_weight(self, name)
skip_substrs = []
if not includes_draft_id_mapping:

View File

@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@ from .utils import (
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
process_eagle_weight,
)
@@ -289,7 +290,7 @@ class EagleMiniCPMModel(nn.Module):
return loaded_params
class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -376,8 +377,13 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
def transform(inputs):
name, loaded_weight = inputs
process_eagle_weight(self, name)
return name, loaded_weight
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
return loader.load_weights(map(transform, weights))

View File

@@ -19,6 +19,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import supports_any_eagle
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
@@ -825,3 +826,25 @@ direct_register_custom_op(
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def process_eagle_weight(
    model: nn.Module,
    name: str,
) -> None:
    """
    Update EAGLE model flags based on loaded weight name.

    Meant to be called for each weight name during weight loading. When
    ``name`` contains ``"lm_head"``, ``model.has_own_lm_head`` is set to
    ``True``; when it contains ``"embed_tokens"``,
    ``model.has_own_embed_tokens`` is set to ``True``. The flags are only
    ever raised here, never reset. Models rejected by
    ``supports_any_eagle`` are left untouched.

    Args:
        model: The model instance (must support EAGLE)
        name: The name of the weight to process
    """
    if supports_any_eagle(model):
        # To prevent overriding with target model's layers
        if "lm_head" in name:
            model.has_own_lm_head = True
        if "embed_tokens" in name:
            model.has_own_embed_tokens = True

View File

@@ -991,6 +991,7 @@ class EagleProposer:
target_language_model = target_model.get_language_model()
else:
target_language_model = target_model
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, "embed_tokens"):
@@ -1002,52 +1003,92 @@ class EagleProposer:
"Target model does not have 'embed_tokens' or 'embedding' attribute"
)
# Check if shapes match and we found the embedding
eagle_shape = self.model.model.embed_tokens.weight.shape
target_shape = target_embed_tokens.weight.shape
if eagle_shape == target_shape:
logger.info(
"Assuming the EAGLE head shares the same vocab embedding"
" with the target model."
)
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
share_embeddings = False
if hasattr(self.model, "has_own_embed_tokens"):
# EAGLE model
if not self.model.has_own_embed_tokens:
share_embeddings = True
logger.info(
"Detected EAGLE model without its own embed_tokens in the"
" checkpoint. Sharing target model embedding weights with the"
" draft model."
)
elif (
isinstance(target_embed_tokens.weight, torch.Tensor)
and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
and torch.equal(
target_embed_tokens.weight, self.model.model.embed_tokens.weight
)
):
share_embeddings = True
logger.info(
"Detected EAGLE model with embed_tokens identical to the target"
" model. Sharing target model embedding weights with the draft"
" model."
)
else:
logger.info(
"Detected EAGLE model with distinct embed_tokens weights. "
"Keeping separate embedding weights from the target model."
)
else:
# MTP model
share_embeddings = True
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"
" from the target model."
"Detected MTP model. "
"Sharing target model embedding weights with the draft model."
)
if share_embeddings:
if hasattr(self.model.model, "embed_tokens"):
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"
"The draft model's vocab embedding will be loaded separately"
" from the target model."
)
# share lm_head with the target model if needed
# some model definitions do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.vllm_config.speculative_config.method != "eagle3":
if hasattr(target_language_model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_language_model.lm_head
else:
if (
hasattr(self.model, "lm_head")
and hasattr(target_language_model, "lm_head")
and self.model.lm_head.weight.shape
== target_language_model.lm_head.weight.shape
):
share_lm_head = False
if hasattr(self.model, "has_own_lm_head"):
# EAGLE model
if not self.model.has_own_lm_head:
share_lm_head = True
logger.info(
"Assuming the EAGLE head shares the same lm_head"
" with the target model."
"Detected EAGLE model without its own lm_head in the checkpoint. "
"Sharing target model lm_head weights with the draft model."
)
elif (
hasattr(target_language_model, "lm_head")
and isinstance(target_language_model.lm_head.weight, torch.Tensor)
and isinstance(self.model.lm_head.weight, torch.Tensor)
and torch.equal(
target_language_model.lm_head.weight, self.model.lm_head.weight
)
):
share_lm_head = True
logger.info(
"Detected EAGLE model with lm_head identical to the target model. "
"Sharing target model lm_head weights with the draft model."
)
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
else:
logger.info(
"The EAGLE head's lm_head will be loaded separately"
" from the target model."
"Detected EAGLE model with distinct lm_head weights. "
"Keeping separate lm_head weights from the target model."
)
else:
# MTP model
share_lm_head = True
logger.info(
"Detected MTP model. "
"Sharing target model lm_head weights with the draft model."
)
if share_lm_head and hasattr(target_language_model, "lm_head"):
if hasattr(self.model, "lm_head"):
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
@torch.inference_mode()
def dummy_run(