[Refactor] Consolidate SupportsEagle (#36063)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-03-13 19:22:40 -04:00
committed by GitHub
parent 54a6db827f
commit 8b346309a5
24 changed files with 229 additions and 235 deletions

View File

@@ -13,15 +13,15 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.models.interfaces import EagleModelMixin
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
class PredictableLlamaModel(nn.Module): class PredictableLlamaModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.aux_hidden_state_layers = tuple[int, ...]()
# Create minimal embed_tokens for embedding # Create minimal embed_tokens for embedding
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (

View File

@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
EagleModelMixin,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsPP, SupportsPP,
@@ -384,7 +385,7 @@ class AfmoeDecoderLayer(nn.Module):
"inputs_embeds": 0, "inputs_embeds": 0,
} }
) )
class AfmoeModel(nn.Module): class AfmoeModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
@@ -421,8 +422,6 @@ class AfmoeModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
@@ -453,15 +452,14 @@ class AfmoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate( for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -691,13 +689,6 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,

View File

@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@@ -313,7 +319,7 @@ class ApertusDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class ApertusModel(nn.Module): class ApertusModel(nn.Module, EagleModelMixin):
def __init__( def __init__(
self, self,
*, *,
@@ -357,8 +363,6 @@ class ApertusModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
@@ -384,13 +388,14 @@ class ApertusModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate( for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -472,7 +477,9 @@ class ApertusModel(nn.Module):
return loaded_params return loaded_params
class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class ApertusForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
# LoRA specific attributes # LoRA specific attributes
@@ -520,13 +527,6 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _init_model( def _init_model(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,

View File

@@ -32,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@@ -170,7 +176,7 @@ class ArceeDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class ArceeModel(nn.Module): class ArceeModel(nn.Module, EagleModelMixin):
"""The transformer model backbone for Arcee (embedding layer + stacked """The transformer model backbone for Arcee (embedding layer + stacked
decoder blocks + final norm).""" decoder blocks + final norm)."""
@@ -218,10 +224,6 @@ class ArceeModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
# For optional capturing of intermediate hidden states
# (not used by default)
self.aux_hidden_state_layers: tuple[int, ...] = tuple()
# Prepare factory for empty intermediate tensors # Prepare factory for empty intermediate tensors
# (for pipeline scheduling) # (for pipeline scheduling)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
@@ -253,15 +255,14 @@ class ArceeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states: list[torch.Tensor] = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate( for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual
) # capture pre-layer hidden state if needed
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
# Send intermediate results to the next pipeline stage # Send intermediate results to the next pipeline stage
@@ -348,7 +349,9 @@ class ArceeModel(nn.Module):
return loaded_params return loaded_params
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class ArceeForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
"""Arcee Model for causal language modeling, integrated with vLLM """Arcee Model for causal language modeling, integrated with vLLM
runtime.""" runtime."""

View File

@@ -47,7 +47,13 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
@@ -256,7 +262,7 @@ class TransformerBlock(torch.nn.Module):
@support_torch_compile @support_torch_compile
class GptOssModel(nn.Module): class GptOssModel(nn.Module, EagleModelMixin):
def __init__( def __init__(
self, self,
*, *,
@@ -285,7 +291,6 @@ class GptOssModel(nn.Module):
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size ["hidden_states", "residual"], self.config.hidden_size
) )
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids) return self.embedding(input_ids)
@@ -309,12 +314,13 @@ class GptOssModel(nn.Module):
x = intermediate_tensors["hidden_states"] x = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state(
[], self.start_layer, x, residual
)
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(x if residual is None else x + residual)
x, residual = layer(x, positions, residual) x, residual = layer(x, positions, residual)
self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": x, "residual": residual}) return IntermediateTensors({"hidden_states": x, "residual": residual})
x, _ = self.norm(x, residual) x, _ = self.norm(x, residual)
@@ -1141,7 +1147,9 @@ class GptOssModel(nn.Module):
) )
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): class GptOssForCausalLM(
nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
):
is_3d_moe_weight: bool = True is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
@@ -1197,13 +1205,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)

View File

@@ -66,7 +66,14 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@@ -586,7 +593,7 @@ class HunYuanDecoderLayer(nn.Module):
"inputs_embeds": 0, "inputs_embeds": 0,
} }
) )
class HunYuanModel(nn.Module): class HunYuanModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
@@ -629,7 +636,6 @@ class HunYuanModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
@@ -654,13 +660,10 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config) cla_factor = _get_cla_factor(self.config)
prev_kv_states = None prev_kv_states = None
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for i, layer in enumerate( for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual, kv_states = layer( hidden_states, residual, kv_states = layer(
positions, positions,
hidden_states, hidden_states,
@@ -673,6 +676,10 @@ class HunYuanModel(nn.Module):
else: else:
prev_kv_states = None prev_kv_states = None
self._maybe_add_hidden_state(
aux_hidden_states, i + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
@@ -904,7 +911,9 @@ class HunYuanModel(nn.Module):
return loaded_params return loaded_params
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): class HunyuanV1ModelBase(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -943,13 +952,6 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,

View File

@@ -86,6 +86,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
@@ -801,6 +802,7 @@ class HunYuanVLForConditionalGeneration(
SupportsPP, SupportsPP,
SupportsQuant, SupportsQuant,
SupportsXDRoPE, SupportsXDRoPE,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
): ):
# To ensure correct weight loading and mapping. # To ensure correct weight loading and mapping.
@@ -988,13 +990,6 @@ class HunYuanVLForConditionalGeneration(
multimodal_embeddings += tuple(image_embeddings) multimodal_embeddings += tuple(image_embeddings)
return multimodal_embeddings return multimodal_embeddings
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,

View File

@@ -1273,6 +1273,25 @@ def supports_any_eagle(
return supports_eagle(model) or supports_eagle3(model) return supports_eagle(model) or supports_eagle3(model)
class EagleModelMixin:
aux_hidden_state_layers: tuple[int, ...] = ()
def _set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.aux_hidden_state_layers = layers
def _maybe_add_hidden_state(
self,
aux_hidden_states: list[torch.Tensor],
layer_idx: int,
hidden_states: torch.Tensor,
residual: torch.Tensor,
) -> list[torch.Tensor]:
if layer_idx in self.aux_hidden_state_layers:
value = hidden_states + residual if residual is not None else hidden_states
aux_hidden_states.append(value)
return aux_hidden_states
@runtime_checkable @runtime_checkable
class SupportsEagle(SupportsEagleBase, Protocol): class SupportsEagle(SupportsEagleBase, Protocol):
"""The interface required for models that support """The interface required for models that support
@@ -1320,24 +1339,48 @@ class SupportsEagle3(SupportsEagleBase, Protocol):
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
""" """
Set which layers should output auxiliary Set which layers should output auxiliary hidden states for EAGLE-3.
hidden states for EAGLE-3.
Args: Args:
layers: Tuple of layer indices that should output auxiliary layers: Tuple of layer indices that should output auxiliary
hidden states. hidden states.
""" """
... parent_ref = self
if hasattr(self, "get_language_model"):
parent_ref = self.get_language_model()
elif hasattr(self, "language_model"):
parent_ref = self.language_model
assert hasattr(parent_ref, "model"), (
"Model instance must have 'model' attribute to set number of layers"
)
assert isinstance(parent_ref.model, EagleModelMixin), (
"Model instance must inherit from EagleModelMixin to set auxiliary layers"
)
parent_ref.model._set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
""" """
Get the layer indices that should output auxiliary hidden states Get the default layer indices that should output auxiliary hidden states
for EAGLE-3. for EAGLE-3 for this model. Models can override this method to provide
different default layers based on their architecture, but it is encouraged
to instead include the layer specification in the model's config if possible.
Returns: Returns:
Tuple of layer indices for auxiliary hidden state outputs. Tuple of layer indices for auxiliary hidden state outputs.
""" """
... parent_ref = self
if hasattr(self, "get_language_model"):
parent_ref = self.get_language_model()
elif hasattr(self, "language_model"):
parent_ref = self.language_model
assert hasattr(parent_ref, "model"), (
"Model instance must have 'model' attribute to get number of layers"
)
assert hasattr(parent_ref.model, "layers"), (
"Model instance must have 'layers' attribute to get number of layers"
)
num_layers = len(parent_ref.model.layers)
return (2, num_layers // 2, num_layers - 3)
@overload @overload

View File

@@ -61,6 +61,7 @@ from vllm.v1.attention.backend import AttentionType
from .adapters import as_embedding_model, as_seq_cls_model from .adapters import as_embedding_model, as_seq_cls_model
from .interfaces import ( from .interfaces import (
EagleModelMixin,
SupportsEagle, SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
@@ -351,7 +352,7 @@ def llama_model_invariants(
# mark_unbacked_dims={"input_ids": 0}, # mark_unbacked_dims={"input_ids": 0},
shape_invariants=llama_model_invariants shape_invariants=llama_model_invariants
) )
class LlamaModel(nn.Module): class LlamaModel(nn.Module, EagleModelMixin):
def __init__( def __init__(
self, self,
*, *,
@@ -389,8 +390,6 @@ class LlamaModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
@@ -417,15 +416,16 @@ class LlamaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate( for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, residual, **extra_layer_kwargs positions, hidden_states, residual, **extra_layer_kwargs
) )
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -556,18 +556,6 @@ class LlamaForCausalLM(
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Override to return default layers for Llama
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _init_model( def _init_model(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,

View File

@@ -55,6 +55,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
@@ -503,7 +504,12 @@ def init_vision_tower_for_llava(
dummy_inputs=LlavaDummyInputsBuilder, dummy_inputs=LlavaDummyInputsBuilder,
) )
class LlavaForConditionalGeneration( class LlavaForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3 nn.Module,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsEagle,
SupportsEagle3,
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -527,13 +533,6 @@ class LlavaForConditionalGeneration(
raise ValueError("Only image modality is supported") raise ValueError("Only image modality is supported")
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.get_language_model().model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.get_language_model().model.layers)
return (2, num_layers // 2, num_layers - 3)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()

View File

@@ -682,13 +682,6 @@ class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)

View File

@@ -63,7 +63,13 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
is_pp_missing_parameter, is_pp_missing_parameter,
@@ -391,7 +397,7 @@ class MiniCPMDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class MiniCPMModel(nn.Module): class MiniCPMModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
@@ -413,8 +419,6 @@ class MiniCPMModel(nn.Module):
self._init_layers(prefix, config, cache_config, quant_config) self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.aux_hidden_state_layers = tuple[int, ...]()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size ["hidden_states", "residual"], self.config.hidden_size
) )
@@ -455,19 +459,18 @@ class MiniCPMModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate( for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
residual, residual,
) )
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -550,7 +553,9 @@ class MiniCPMModel(nn.Module):
return loaded_params return loaded_params
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): class MiniCPMForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -611,13 +616,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,

View File

@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
@@ -409,7 +410,12 @@ def init_vision_tower_for_llava(
dummy_inputs=Mistral3DummyInputsBuilder, dummy_inputs=Mistral3DummyInputsBuilder,
) )
class Mistral3ForConditionalGeneration( class Mistral3ForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3 nn.Module,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsEagle,
SupportsEagle3,
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -433,13 +439,6 @@ class Mistral3ForConditionalGeneration(
raise ValueError("Only image modality is supported") raise ValueError("Only image modality is supported")
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.get_language_model().model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.get_language_model().model.layers)
return (2, num_layers // 2, num_layers - 3)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()

View File

@@ -798,20 +798,16 @@ class Llama4ForConditionalGeneration(
self.num_moe_layers = len(self.moe_layers) self.num_moe_layers = len(self.moe_layers)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM) # Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "set_aux_hidden_state_layers") assert hasattr(self.language_model, "set_aux_hidden_state_layers")
self.language_model.set_aux_hidden_state_layers(layers) self.language_model.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Get the layer indices for auxiliary hidden state outputs.
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
# Delegate to underlying language model (Llama4ForCausalLM) # Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") assert hasattr(
return self.language_model.get_eagle3_aux_hidden_state_layers() self.language_model, "get_eagle3_default_aux_hidden_state_layers"
)
return self.language_model.get_eagle3_default_aux_hidden_state_layers()
def set_eplb_state( def set_eplb_state(
self, self,

View File

@@ -62,7 +62,13 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@@ -349,7 +355,7 @@ def qwen_2_model_invariants(
}, },
shape_invariants=qwen_2_model_invariants, shape_invariants=qwen_2_model_invariants,
) )
class Qwen2Model(nn.Module): class Qwen2Model(nn.Module, EagleModelMixin):
def __init__( def __init__(
self, self,
*, *,
@@ -410,8 +416,6 @@ class Qwen2Model(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
@@ -433,13 +437,14 @@ class Qwen2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate( for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -519,7 +524,9 @@ class Qwen2Model(nn.Module):
return loaded_params return loaded_params
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): class Qwen2ForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -566,13 +573,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,

View File

@@ -89,6 +89,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMRoPE, SupportsMRoPE,
@@ -1000,6 +1001,7 @@ class Qwen2_5_VLForConditionalGeneration(
SupportsLoRA, SupportsLoRA,
SupportsPP, SupportsPP,
SupportsQuant, SupportsQuant,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsMultiModalPruning, SupportsMultiModalPruning,
SupportsMRoPE, SupportsMRoPE,
@@ -1143,13 +1145,6 @@ class Qwen2_5_VLForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Qwen2_5_VLImageInputs | None: ) -> Qwen2_5_VLImageInputs | None:

View File

@@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta from vllm.transformers_utils.config import set_default_rope_theta
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
@@ -258,7 +258,9 @@ class Qwen3Model(Qwen2Model):
) )
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): class Qwen3ForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -307,13 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)

View File

@@ -65,7 +65,14 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@@ -427,7 +434,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Qwen3MoeModel(nn.Module): class Qwen3MoeModel(nn.Module, EagleModelMixin):
def __init__( def __init__(
self, self,
*, *,
@@ -461,8 +468,6 @@ class Qwen3MoeModel(nn.Module):
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
# Track layers for auxiliary hidden state outputs (EAGLE3)
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
@@ -485,18 +490,17 @@ class Qwen3MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state(
[], self.start_layer, hidden_states, residual
)
for layer_idx, layer in enumerate( for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer), islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer, start=self.start_layer,
): ):
# Collect auxiliary hidden states if specified
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_state = (
hidden_states + residual if residual is not None else hidden_states
)
aux_hidden_states.append(aux_hidden_state)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -666,7 +670,7 @@ class Qwen3MoeModel(nn.Module):
class Qwen3MoeForCausalLM( class Qwen3MoeForCausalLM(
nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
@@ -751,13 +755,6 @@ class Qwen3MoeForCausalLM(
moe.n_redundant_experts = self.num_redundant_experts moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map() moe.experts.update_expert_map()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)

View File

@@ -101,6 +101,7 @@ from vllm.utils.math_utils import round_up
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMRoPE, SupportsMRoPE,
@@ -1275,13 +1276,10 @@ class Qwen3LLMModel(Qwen3Model):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in islice( for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer enumerate(self.layers), self.start_layer, self.end_layer
): ):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
@@ -1295,6 +1293,9 @@ class Qwen3LLMModel(Qwen3Model):
hidden_states hidden_states
+ deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
) )
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -1351,6 +1352,7 @@ class Qwen3VLForConditionalGeneration(
SupportsLoRA, SupportsLoRA,
SupportsPP, SupportsPP,
SupportsMRoPE, SupportsMRoPE,
SupportsEagle,
SupportsEagle3, SupportsEagle3,
SupportsMultiModalPruning, SupportsMultiModalPruning,
): ):
@@ -1449,13 +1451,6 @@ class Qwen3VLForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _get_deepstack_input_embeds( def _get_deepstack_input_embeds(
self, self,
num_tokens: int, num_tokens: int,

View File

@@ -102,19 +102,17 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state(
[], self.start_layer, hidden_states, residual
)
for layer_idx, layer in islice( for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer enumerate(self.layers), self.start_layer, self.end_layer
): ):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
residual, residual,
) )
if deepstack_input_embeds is not None and layer_idx in range( if deepstack_input_embeds is not None and layer_idx in range(
0, len(deepstack_input_embeds) 0, len(deepstack_input_embeds)
): ):
@@ -123,6 +121,10 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
+ deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
) )
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}

View File

@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.models.interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsPP,
)
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
return loaded_params return loaded_params
class StepDecoderModel(nn.Module): class StepDecoderModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = getattr(
config, "aux_hidden_state_layers", ()
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], ["hidden_states", "residual"],
config.hidden_size, config.hidden_size,
@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = [] aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
if idx in self.aux_hidden_state_layers:
if residual is None:
aux_hidden_states.append(hidden_states)
else:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
return hidden_states return hidden_states
class Step1ForCausalLM(nn.Module, SupportsPP): class Step1ForCausalLM(nn.Module, SupportsPP, SupportsEagle, SupportsEagle3):
packed_modules_mapping = STEP_PACKED_MODULES_MAPPING packed_modules_mapping = STEP_PACKED_MODULES_MAPPING
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@@ -618,6 +618,6 @@ class Base(
# Ensure that the capture hooks are installed before dynamo traces the model # Ensure that the capture hooks are installed before dynamo traces the model
maybe_install_capturing_hooks(self.model) maybe_install_capturing_hooks(self.model)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = self.text_config.num_hidden_layers num_layers = self.text_config.num_hidden_layers
return (2, num_layers // 2, num_layers - 3) return (2, num_layers // 2, num_layers - 3)

View File

@@ -27,7 +27,7 @@ def set_eagle3_aux_hidden_state_layers(
if aux_layers: if aux_layers:
logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers) logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
else: else:
aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers() aux_layers = eagle3_model.get_eagle3_default_aux_hidden_state_layers()
logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers) logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
eagle3_model.set_aux_hidden_state_layers(aux_layers) eagle3_model.set_aux_hidden_state_layers(aux_layers)

View File

@@ -4556,7 +4556,9 @@ class GPUModelRunner(
aux_layers, aux_layers,
) )
else: else:
aux_layers = self.model.get_eagle3_aux_hidden_state_layers() aux_layers = (
self.model.get_eagle3_default_aux_hidden_state_layers()
)
self.model.set_aux_hidden_state_layers(aux_layers) self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()