[Refactor] Consolidate SupportsEagle (#36063)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>

Author: Benjamin Chislett
Date: 2026-03-13 19:22:40 -04:00
Committed by: GitHub
Parent: 54a6db827f
Commit: 8b346309a5

24 changed files with 229 additions and 235 deletions
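Summary of the consolidation: every model class below previously carried its own copy of the two EAGLE-3 hooks; this commit deletes those copies and moves the logic into a shared `EagleModelMixin` plus default implementations on the `SupportsEagle3` protocol (see the interfaces hunks below). A condensed sketch of the per-model boilerplate being removed — `MyForCausalLM` is a hypothetical stand-in, but the method bodies are verbatim from the diffs:

```python
# Hypothetical model class; these two methods were duplicated across more
# than a dozen model files and are deleted in favor of the shared defaults.
class MyForCausalLM:
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)
```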

View File

@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from vllm.model_executor.models.interfaces import (
+    EagleModelMixin,
     SupportsEagle3,
     SupportsLoRA,
     SupportsPP,
@@ -384,7 +385,7 @@ class AfmoeDecoderLayer(nn.Module):
         "inputs_embeds": 0,
     }
 )
-class AfmoeModel(nn.Module):
+class AfmoeModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
@@ -421,8 +422,6 @@ class AfmoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -453,15 +452,14 @@ class AfmoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(
-                    hidden_states + residual if residual is not None else hidden_states
-                )
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -691,13 +689,6 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,
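
For orientation, the same forward-loop change repeats in each model file below. A minimal runnable sketch of the new capture pattern, assuming only what the diff shows — `TinyModel` and its identity layers are illustrative, not part of the commit:

```python
import torch
import torch.nn as nn

from vllm.model_executor.models.interfaces import EagleModelMixin


class TinyModel(nn.Module, EagleModelMixin):
    """Illustrative only: brackets each layer with _maybe_add_hidden_state."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Identity() for _ in range(4)])

    def forward(self, hidden_states: torch.Tensor):
        residual = None
        # Index 0 captures the pre-layer state; idx + 1 captures each
        # post-layer state, matching the loops in the hunks above and below.
        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states)
            self._maybe_add_hidden_state(aux_hidden_states, idx + 1, hidden_states, residual)
        return hidden_states, aux_hidden_states


model = TinyModel()
model._set_aux_hidden_state_layers((0, 2))  # capture before layer 0 and after layer 1
_, aux = model(torch.zeros(1, 8))
assert len(aux) == 2
```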

View File

@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backend import AttentionType
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -313,7 +319,7 @@ class ApertusDecoderLayer(nn.Module):
 
 
 @support_torch_compile
-class ApertusModel(nn.Module):
+class ApertusModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -357,8 +363,6 @@ class ApertusModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -384,13 +388,14 @@ class ApertusModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -472,7 +477,9 @@ class ApertusModel(nn.Module):
         return loaded_params
 
 
-class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class ApertusForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
 
     # LoRA specific attributes
@@ -520,13 +527,6 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _init_model(
         self,
         vllm_config: VllmConfig,

View File

@@ -32,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import (
 )
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -170,7 +176,7 @@ class ArceeDecoderLayer(nn.Module):
 
 
 @support_torch_compile
-class ArceeModel(nn.Module):
+class ArceeModel(nn.Module, EagleModelMixin):
     """The transformer model backbone for Arcee (embedding layer + stacked
     decoder blocks + final norm)."""
 
@@ -218,10 +224,6 @@ class ArceeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        # For optional capturing of intermediate hidden states
-        # (not used by default)
-        self.aux_hidden_state_layers: tuple[int, ...] = tuple()
-
         # Prepare factory for empty intermediate tensors
         # (for pipeline scheduling)
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
@@ -253,15 +255,14 @@ class ArceeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states: list[torch.Tensor] = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(
-                    hidden_states + residual
-                )  # capture pre-layer hidden state if needed
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             # Send intermediate results to the next pipeline stage
@@ -348,7 +349,9 @@ class ArceeModel(nn.Module):
         return loaded_params
 
 
-class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class ArceeForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     """Arcee Model for causal language modeling, integrated with vLLM
     runtime."""

View File

@@ -47,7 +47,13 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backend import AttentionType
 
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
@@ -256,7 +262,7 @@ class TransformerBlock(torch.nn.Module):
 
 
 @support_torch_compile
-class GptOssModel(nn.Module):
+class GptOssModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -285,7 +291,6 @@ class GptOssModel(nn.Module):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], self.config.hidden_size
         )
-        self.aux_hidden_state_layers = tuple[int, ...]()
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embedding(input_ids)
@@ -309,12 +314,13 @@ class GptOssModel(nn.Module):
             x = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state(
+            [], self.start_layer, x, residual
+        )
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            if i in self.aux_hidden_state_layers:
-                aux_hidden_states.append(x if residual is None else x + residual)
             x, residual = layer(x, positions, residual)
+            self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": x, "residual": residual})
         x, _ = self.norm(x, residual)
@@ -1141,7 +1147,9 @@ class GptOssModel(nn.Module):
 )
 
 
-class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
+class GptOssForCausalLM(
+    nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
+):
     is_3d_moe_weight: bool = True
 
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
@@ -1197,13 +1205,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
             self.model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

View File

@@ -66,7 +66,14 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backend import AttentionType
 
-from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -586,7 +593,7 @@ class HunYuanDecoderLayer(nn.Module):
         "inputs_embeds": 0,
     }
 )
-class HunYuanModel(nn.Module):
+class HunYuanModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
@@ -629,7 +636,6 @@ class HunYuanModel(nn.Module):
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
             self.norm = PPMissingLayer()
-        self.aux_hidden_state_layers = tuple[int, ...]()
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -654,13 +660,10 @@ class HunYuanModel(nn.Module):
         cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
 
-        aux_hidden_states = []
-
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for i, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if i in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual, kv_states = layer(
                 positions,
                 hidden_states,
@@ -673,6 +676,10 @@ class HunYuanModel(nn.Module):
             else:
                 prev_kv_states = None
 
+            self._maybe_add_hidden_state(
+                aux_hidden_states, i + 1, hidden_states, residual
+            )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
@@ -904,7 +911,9 @@ class HunYuanModel(nn.Module):
         return loaded_params
 
 
-class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class HunyuanV1ModelBase(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -943,13 +952,6 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         else:
             self.lm_head = PPMissingLayer()
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,

View File

@@ -86,6 +86,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
@@ -801,6 +802,7 @@ class HunYuanVLForConditionalGeneration(
     SupportsPP,
     SupportsQuant,
    SupportsXDRoPE,
+    SupportsEagle,
     SupportsEagle3,
 ):
     # To ensure correct weight loading and mapping.
@@ -988,13 +990,6 @@ class HunYuanVLForConditionalGeneration(
             multimodal_embeddings += tuple(image_embeddings)
         return multimodal_embeddings
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.language_model.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.language_model.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,

View File

@@ -1273,6 +1273,25 @@ def supports_any_eagle(
     return supports_eagle(model) or supports_eagle3(model)
 
 
+class EagleModelMixin:
+    aux_hidden_state_layers: tuple[int, ...] = ()
+
+    def _set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.aux_hidden_state_layers = layers
+
+    def _maybe_add_hidden_state(
+        self,
+        aux_hidden_states: list[torch.Tensor],
+        layer_idx: int,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> list[torch.Tensor]:
+        if layer_idx in self.aux_hidden_state_layers:
+            value = hidden_states + residual if residual is not None else hidden_states
+            aux_hidden_states.append(value)
+        return aux_hidden_states
+
+
 @runtime_checkable
 class SupportsEagle(SupportsEagleBase, Protocol):
     """The interface required for models that support
@@ -1320,24 +1339,48 @@ class SupportsEagle3(SupportsEagleBase, Protocol):
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         """
-        Set which layers should output auxiliary
-        hidden states for EAGLE-3.
+        Set which layers should output auxiliary hidden states for EAGLE-3.
 
         Args:
             layers: Tuple of layer indices that should output auxiliary
                 hidden states.
         """
-        ...
+        parent_ref = self
+        if hasattr(self, "get_language_model"):
+            parent_ref = self.get_language_model()
+        elif hasattr(self, "language_model"):
+            parent_ref = self.language_model
+
+        assert hasattr(parent_ref, "model"), (
+            "Model instance must have 'model' attribute to set number of layers"
+        )
+        assert isinstance(parent_ref.model, EagleModelMixin), (
+            "Model instance must inherit from EagleModelMixin to set auxiliary layers"
+        )
+        parent_ref.model._set_aux_hidden_state_layers(layers)
 
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
         """
-        Get the layer indices that should output auxiliary hidden states
-        for EAGLE-3.
+        Get the default layer indices that should output auxiliary hidden states
+        for EAGLE-3 for this model. Models can override this method to provide
+        different default layers based on their architecture, but it is encouraged
+        to instead include the layer specification in the model's config if possible.
 
         Returns:
             Tuple of layer indices for auxiliary hidden state outputs.
         """
-        ...
+        parent_ref = self
+        if hasattr(self, "get_language_model"):
+            parent_ref = self.get_language_model()
+        elif hasattr(self, "language_model"):
+            parent_ref = self.language_model
+
+        assert hasattr(parent_ref, "model"), (
+            "Model instance must have 'model' attribute to get number of layers"
+        )
+        assert hasattr(parent_ref.model, "layers"), (
+            "Model instance must have 'layers' attribute to get number of layers"
+        )
+        num_layers = len(parent_ref.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
 
 
 @overload
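
Taken together, models now only need to declare `SupportsEagle3` and have their inner decoder inherit `EagleModelMixin`; both protocol methods resolve the decoder via `get_language_model()`, then `.language_model`, then `self`. A hedged sketch of the call sequence a caller might use (the override-from-speculative-config step is illustrative, echoing the docstring deleted from the Llama hunk below):

```python
# `model` is any instance declaring SupportsEagle3 whose inner decoder
# inherits EagleModelMixin (per the hunks above).
layers = model.get_eagle3_default_aux_hidden_state_layers()
# For a hypothetical 32-layer decoder this yields (2, 16, 29); a runner may
# substitute layers from the speculative config instead of these defaults.
model.set_aux_hidden_state_layers(layers)
```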

View File

@@ -61,6 +61,7 @@ from vllm.v1.attention.backend import AttentionType
 
 from .adapters import as_embedding_model, as_seq_cls_model
 from .interfaces import (
+    EagleModelMixin,
     SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
@@ -351,7 +352,7 @@ def llama_model_invariants(
     # mark_unbacked_dims={"input_ids": 0},
     shape_invariants=llama_model_invariants
 )
-class LlamaModel(nn.Module):
+class LlamaModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -389,8 +390,6 @@ class LlamaModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -417,15 +416,16 @@ class LlamaModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions, hidden_states, residual, **extra_layer_kwargs
             )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -556,18 +556,6 @@ class LlamaForCausalLM(
             self.model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        """Override to return default layers for Llama
-
-        Note: The GPU model runner will override this with layers from
-        the speculative config if available, providing dynamic configuration.
-        """
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _init_model(
         self,
         vllm_config: VllmConfig,

View File

@@ -55,6 +55,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
@@ -503,7 +504,12 @@ def init_vision_tower_for_llava(
     dummy_inputs=LlavaDummyInputsBuilder,
 )
 class LlavaForConditionalGeneration(
-    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
+    nn.Module,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -527,13 +533,6 @@ class LlavaForConditionalGeneration(
             raise ValueError("Only image modality is supported")
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.get_language_model().model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.get_language_model().model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()

View File

@@ -682,13 +682,6 @@ class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
             self.model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

View File

@@ -63,7 +63,13 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -391,7 +397,7 @@ class MiniCPMDecoderLayer(nn.Module):
 
 
 @support_torch_compile
-class MiniCPMModel(nn.Module):
+class MiniCPMModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
@@ -413,8 +419,6 @@ class MiniCPMModel(nn.Module):
         self._init_layers(prefix, config, cache_config, quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], self.config.hidden_size
         )
@@ -455,19 +459,18 @@ class MiniCPMModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(
-                    hidden_states + residual if residual is not None else hidden_states
-                )
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -550,7 +553,9 @@ class MiniCPMModel(nn.Module):
         return loaded_params
 
 
-class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class MiniCPMForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -611,13 +616,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,

View File

@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
@@ -409,7 +410,12 @@ def init_vision_tower_for_llava(
     dummy_inputs=Mistral3DummyInputsBuilder,
 )
 class Mistral3ForConditionalGeneration(
-    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
+    nn.Module,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -433,13 +439,6 @@ class Mistral3ForConditionalGeneration(
             raise ValueError("Only image modality is supported")
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.get_language_model().model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.get_language_model().model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()

View File

@@ -798,20 +798,16 @@ class Llama4ForConditionalGeneration(
         self.num_moe_layers = len(self.moe_layers)
 
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         """Set which layers should output auxiliary hidden states for EAGLE3."""
         # Delegate to underlying language model (Llama4ForCausalLM)
         assert hasattr(self.language_model, "set_aux_hidden_state_layers")
         self.language_model.set_aux_hidden_state_layers(layers)
 
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        """Get the layer indices for auxiliary hidden state outputs.
-
-        Note: The GPU model runner will override this with layers from
-        the speculative config if available, providing dynamic configuration.
-        """
+    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
         # Delegate to underlying language model (Llama4ForCausalLM)
-        assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
-        return self.language_model.get_eagle3_aux_hidden_state_layers()
+        assert hasattr(
+            self.language_model, "get_eagle3_default_aux_hidden_state_layers"
+        )
+        return self.language_model.get_eagle3_default_aux_hidden_state_layers()
 
     def set_eplb_state(
         self,

View File

@@ -62,7 +62,13 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta
 from vllm.v1.attention.backend import AttentionType
 
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -349,7 +355,7 @@ def qwen_2_model_invariants(
     },
     shape_invariants=qwen_2_model_invariants,
 )
-class Qwen2Model(nn.Module):
+class Qwen2Model(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -410,8 +416,6 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -433,13 +437,14 @@ class Qwen2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -519,7 +524,9 @@ class Qwen2Model(nn.Module):
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class Qwen2ForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -566,13 +573,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,

View File

@@ -89,6 +89,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMRoPE,
@@ -1000,6 +1001,7 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsLoRA,
     SupportsPP,
     SupportsQuant,
+    SupportsEagle,
     SupportsEagle3,
     SupportsMultiModalPruning,
     SupportsMRoPE,
@@ -1143,13 +1145,6 @@ class Qwen2_5_VLForConditionalGeneration(
             self.language_model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.language_model.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.language_model.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2_5_VLImageInputs | None:

View File

@@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta
 from vllm.v1.attention.backend import AttentionType
 
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
@@ -258,7 +258,9 @@ class Qwen3Model(Qwen2Model):
 )
 
 
-class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class Qwen3ForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -307,13 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
             self.model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

View File

@@ -65,7 +65,14 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -427,7 +434,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
 
 
 @support_torch_compile
-class Qwen3MoeModel(nn.Module):
+class Qwen3MoeModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -461,8 +468,6 @@ class Qwen3MoeModel(nn.Module):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
-        # Track layers for auxiliary hidden state outputs (EAGLE3)
-        self.aux_hidden_state_layers: tuple[int, ...] = ()
 
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -485,18 +490,17 @@ class Qwen3MoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state(
+            [], self.start_layer, hidden_states, residual
+        )
         for layer_idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer),
             start=self.start_layer,
         ):
-            # Collect auxiliary hidden states if specified
-            if layer_idx in self.aux_hidden_state_layers:
-                aux_hidden_state = (
-                    hidden_states + residual if residual is not None else hidden_states
-                )
-                aux_hidden_states.append(aux_hidden_state)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, layer_idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -666,7 +670,7 @@ class Qwen3MoeModel(nn.Module):
 
 
 class Qwen3MoeForCausalLM(
-    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
 ):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -751,13 +755,6 @@ class Qwen3MoeForCausalLM(
             moe.n_redundant_experts = self.num_redundant_experts
         moe.experts.update_expert_map()
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

View File

@@ -101,6 +101,7 @@ from vllm.utils.math_utils import round_up
 
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMRoPE,
@@ -1275,13 +1276,10 @@ class Qwen3LLMModel(Qwen3Model):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
-
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for layer_idx, layer in islice(
             enumerate(self.layers), self.start_layer, self.end_layer
         ):
-            if layer_idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
@@ -1295,6 +1293,9 @@ class Qwen3LLMModel(Qwen3Model):
                     hidden_states
                     + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                 )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, layer_idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -1351,6 +1352,7 @@ class Qwen3VLForConditionalGeneration(
     SupportsLoRA,
     SupportsPP,
     SupportsMRoPE,
+    SupportsEagle,
     SupportsEagle3,
     SupportsMultiModalPruning,
 ):
@@ -1449,13 +1451,6 @@ class Qwen3VLForConditionalGeneration(
             self.language_model.make_empty_intermediate_tensors
         )
 
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.language_model.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.language_model.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _get_deepstack_input_embeds(
         self,
         num_tokens: int,

View File

@@ -102,19 +102,17 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
-
+        aux_hidden_states = self._maybe_add_hidden_state(
+            [], self.start_layer, hidden_states, residual
+        )
         for layer_idx, layer in islice(
             enumerate(self.layers), self.start_layer, self.end_layer
         ):
-            if layer_idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
             if deepstack_input_embeds is not None and layer_idx in range(
                 0, len(deepstack_input_embeds)
             ):
@@ -123,6 +121,10 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                     + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                 )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, layer_idx + 1, hidden_states, residual
+            )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}

View File

@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsPP,
+)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
         return loaded_params
 
 
-class StepDecoderModel(nn.Module):
+class StepDecoderModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers: tuple[int, ...] = getattr(
-            config, "aux_hidden_state_layers", ()
-        )
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"],
             config.hidden_size,
@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
-            if idx in self.aux_hidden_state_layers:
-                if residual is None:
-                    aux_hidden_states.append(hidden_states)
-                else:
-                    aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
         return hidden_states
 
 
-class Step1ForCausalLM(nn.Module, SupportsPP):
+class Step1ForCausalLM(nn.Module, SupportsPP, SupportsEagle, SupportsEagle3):
     packed_modules_mapping = STEP_PACKED_MODULES_MAPPING
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@@ -618,6 +618,6 @@ class Base(
         # Ensure that the capture hooks are installed before dynamo traces the model
         maybe_install_capturing_hooks(self.model)
 
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
         num_layers = self.text_config.num_hidden_layers
         return (2, num_layers // 2, num_layers - 3)
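
As a worked example of this shared default — an early, a middle, and a late decoder layer (the `num_layers = 32` value is illustrative):

```python
# Illustrative arithmetic for the default EAGLE-3 aux hidden state layers.
num_layers = 32
assert (2, num_layers // 2, num_layers - 3) == (2, 16, 29)
```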