[Refactor] Consolidate SupportsEagle (#36063)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
54a6db827f
commit
8b346309a5
@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
@@ -384,7 +385,7 @@ class AfmoeDecoderLayer(nn.Module):
|
||||
"inputs_embeds": 0,
|
||||
}
|
||||
)
|
||||
class AfmoeModel(nn.Module):
|
||||
class AfmoeModel(nn.Module, EagleModelMixin):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -421,8 +422,6 @@ class AfmoeModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
@@ -453,15 +452,14 @@ class AfmoeModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(
|
||||
hidden_states + residual if residual is not None else hidden_states
|
||||
)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -691,13 +689,6 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -313,7 +319,7 @@ class ApertusDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ApertusModel(nn.Module):
|
||||
class ApertusModel(nn.Module, EagleModelMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -357,8 +363,6 @@ class ApertusModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
@@ -384,13 +388,14 @@ class ApertusModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -472,7 +477,9 @@ class ApertusModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class ApertusForCausalLM(
|
||||
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
|
||||
):
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
# LoRA specific attributes
|
||||
@@ -520,13 +527,6 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@@ -32,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -170,7 +176,7 @@ class ArceeDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class ArceeModel(nn.Module):
|
||||
class ArceeModel(nn.Module, EagleModelMixin):
|
||||
"""The transformer model backbone for Arcee (embedding layer + stacked
|
||||
decoder blocks + final norm)."""
|
||||
|
||||
@@ -218,10 +224,6 @@ class ArceeModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
# For optional capturing of intermediate hidden states
|
||||
# (not used by default)
|
||||
self.aux_hidden_state_layers: tuple[int, ...] = tuple()
|
||||
|
||||
# Prepare factory for empty intermediate tensors
|
||||
# (for pipeline scheduling)
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
@@ -253,15 +255,14 @@ class ArceeModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states: list[torch.Tensor] = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(
|
||||
hidden_states + residual
|
||||
) # capture pre-layer hidden state if needed
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
# Send intermediate results to the next pipeline stage
|
||||
@@ -348,7 +349,9 @@ class ArceeModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class ArceeForCausalLM(
|
||||
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
|
||||
):
|
||||
"""Arcee Model for causal language modeling, integrated with vLLM
|
||||
runtime."""
|
||||
|
||||
|
||||
@@ -47,7 +47,13 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
@@ -256,7 +262,7 @@ class TransformerBlock(torch.nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class GptOssModel(nn.Module):
|
||||
class GptOssModel(nn.Module, EagleModelMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -285,7 +291,6 @@ class GptOssModel(nn.Module):
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], self.config.hidden_size
|
||||
)
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embedding(input_ids)
|
||||
@@ -309,12 +314,13 @@ class GptOssModel(nn.Module):
|
||||
x = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state(
|
||||
[], self.start_layer, x, residual
|
||||
)
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
if i in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(x if residual is None else x + residual)
|
||||
x, residual = layer(x, positions, residual)
|
||||
self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": x, "residual": residual})
|
||||
x, _ = self.norm(x, residual)
|
||||
@@ -1141,7 +1147,9 @@ class GptOssModel(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
class GptOssForCausalLM(
|
||||
nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
|
||||
):
|
||||
is_3d_moe_weight: bool = True
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
@@ -1197,13 +1205,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
@@ -66,7 +66,14 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
MixtureOfExperts,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -586,7 +593,7 @@ class HunYuanDecoderLayer(nn.Module):
|
||||
"inputs_embeds": 0,
|
||||
}
|
||||
)
|
||||
class HunYuanModel(nn.Module):
|
||||
class HunYuanModel(nn.Module, EagleModelMixin):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -629,7 +636,6 @@ class HunYuanModel(nn.Module):
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -654,13 +660,10 @@ class HunYuanModel(nn.Module):
|
||||
|
||||
cla_factor = _get_cla_factor(self.config)
|
||||
prev_kv_states = None
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for i, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if i in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
|
||||
hidden_states, residual, kv_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
@@ -673,6 +676,10 @@ class HunYuanModel(nn.Module):
|
||||
else:
|
||||
prev_kv_states = None
|
||||
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, i + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
@@ -904,7 +911,9 @@ class HunYuanModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
class HunyuanV1ModelBase(
|
||||
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -943,13 +952,6 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -86,6 +86,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
@@ -801,6 +802,7 @@ class HunYuanVLForConditionalGeneration(
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
SupportsXDRoPE,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
):
|
||||
# To ensure correct weight loading and mapping.
|
||||
@@ -988,13 +990,6 @@ class HunYuanVLForConditionalGeneration(
|
||||
multimodal_embeddings += tuple(image_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.language_model.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.language_model.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -1273,6 +1273,25 @@ def supports_any_eagle(
|
||||
return supports_eagle(model) or supports_eagle3(model)
|
||||
|
||||
|
||||
class EagleModelMixin:
|
||||
aux_hidden_state_layers: tuple[int, ...] = ()
|
||||
|
||||
def _set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.aux_hidden_state_layers = layers
|
||||
|
||||
def _maybe_add_hidden_state(
|
||||
self,
|
||||
aux_hidden_states: list[torch.Tensor],
|
||||
layer_idx: int,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> list[torch.Tensor]:
|
||||
if layer_idx in self.aux_hidden_state_layers:
|
||||
value = hidden_states + residual if residual is not None else hidden_states
|
||||
aux_hidden_states.append(value)
|
||||
return aux_hidden_states
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsEagle(SupportsEagleBase, Protocol):
|
||||
"""The interface required for models that support
|
||||
@@ -1320,24 +1339,48 @@ class SupportsEagle3(SupportsEagleBase, Protocol):
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
"""
|
||||
Set which layers should output auxiliary
|
||||
hidden states for EAGLE-3.
|
||||
Set which layers should output auxiliary hidden states for EAGLE-3.
|
||||
|
||||
Args:
|
||||
layers: Tuple of layer indices that should output auxiliary
|
||||
hidden states.
|
||||
"""
|
||||
...
|
||||
parent_ref = self
|
||||
if hasattr(self, "get_language_model"):
|
||||
parent_ref = self.get_language_model()
|
||||
elif hasattr(self, "language_model"):
|
||||
parent_ref = self.language_model
|
||||
assert hasattr(parent_ref, "model"), (
|
||||
"Model instance must have 'model' attribute to set number of layers"
|
||||
)
|
||||
assert isinstance(parent_ref.model, EagleModelMixin), (
|
||||
"Model instance must inherit from EagleModelMixin to set auxiliary layers"
|
||||
)
|
||||
parent_ref.model._set_aux_hidden_state_layers(layers)
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
"""
|
||||
Get the layer indices that should output auxiliary hidden states
|
||||
for EAGLE-3.
|
||||
Get the default layer indices that should output auxiliary hidden states
|
||||
for EAGLE-3 for this model. Models can override this method to provide
|
||||
different default layers based on their architecture, but it is encouraged
|
||||
to instead include the layer specification in the model's config if possible.
|
||||
|
||||
Returns:
|
||||
Tuple of layer indices for auxiliary hidden state outputs.
|
||||
"""
|
||||
...
|
||||
parent_ref = self
|
||||
if hasattr(self, "get_language_model"):
|
||||
parent_ref = self.get_language_model()
|
||||
elif hasattr(self, "language_model"):
|
||||
parent_ref = self.language_model
|
||||
assert hasattr(parent_ref, "model"), (
|
||||
"Model instance must have 'model' attribute to get number of layers"
|
||||
)
|
||||
assert hasattr(parent_ref.model, "layers"), (
|
||||
"Model instance must have 'layers' attribute to get number of layers"
|
||||
)
|
||||
num_layers = len(parent_ref.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
|
||||
@overload
|
||||
|
||||
@@ -61,6 +61,7 @@ from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .adapters import as_embedding_model, as_seq_cls_model
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
@@ -351,7 +352,7 @@ def llama_model_invariants(
|
||||
# mark_unbacked_dims={"input_ids": 0},
|
||||
shape_invariants=llama_model_invariants
|
||||
)
|
||||
class LlamaModel(nn.Module):
|
||||
class LlamaModel(nn.Module, EagleModelMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -389,8 +390,6 @@ class LlamaModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
@@ -417,15 +416,16 @@ class LlamaModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual, **extra_layer_kwargs
|
||||
)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -556,18 +556,6 @@ class LlamaForCausalLM(
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
"""Override to return default layers for Llama
|
||||
|
||||
Note: The GPU model runner will override this with layers from
|
||||
the speculative config if available, providing dynamic configuration.
|
||||
"""
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@@ -55,6 +55,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
@@ -503,7 +504,12 @@ def init_vision_tower_for_llava(
|
||||
dummy_inputs=LlavaDummyInputsBuilder,
|
||||
)
|
||||
class LlavaForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
|
||||
nn.Module,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
@@ -527,13 +533,6 @@ class LlavaForConditionalGeneration(
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.get_language_model().model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.get_language_model().model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -682,13 +682,6 @@ class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
@@ -63,7 +63,13 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
is_pp_missing_parameter,
|
||||
@@ -391,7 +397,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MiniCPMModel(nn.Module):
|
||||
class MiniCPMModel(nn.Module, EagleModelMixin):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -413,8 +419,6 @@ class MiniCPMModel(nn.Module):
|
||||
self._init_layers(prefix, config, cache_config, quant_config)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], self.config.hidden_size
|
||||
)
|
||||
@@ -455,19 +459,18 @@ class MiniCPMModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(
|
||||
hidden_states + residual if residual is not None else hidden_states
|
||||
)
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -550,7 +553,9 @@ class MiniCPMModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
class MiniCPMForCausalLM(
|
||||
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -611,13 +616,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
@@ -409,7 +410,12 @@ def init_vision_tower_for_llava(
|
||||
dummy_inputs=Mistral3DummyInputsBuilder,
|
||||
)
|
||||
class Mistral3ForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
|
||||
nn.Module,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
@@ -433,13 +439,6 @@ class Mistral3ForConditionalGeneration(
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.get_language_model().model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.get_language_model().model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -798,20 +798,16 @@ class Llama4ForConditionalGeneration(
|
||||
self.num_moe_layers = len(self.moe_layers)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
"""Set which layers should output auxiliary hidden states for EAGLE3."""
|
||||
# Delegate to underlying language model (Llama4ForCausalLM)
|
||||
assert hasattr(self.language_model, "set_aux_hidden_state_layers")
|
||||
self.language_model.set_aux_hidden_state_layers(layers)
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
"""Get the layer indices for auxiliary hidden state outputs.
|
||||
|
||||
Note: The GPU model runner will override this with layers from
|
||||
the speculative config if available, providing dynamic configuration.
|
||||
"""
|
||||
def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
# Delegate to underlying language model (Llama4ForCausalLM)
|
||||
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
|
||||
return self.language_model.get_eagle3_aux_hidden_state_layers()
|
||||
assert hasattr(
|
||||
self.language_model, "get_eagle3_default_aux_hidden_state_layers"
|
||||
)
|
||||
return self.language_model.get_eagle3_default_aux_hidden_state_layers()
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
|
||||
@@ -62,7 +62,13 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -349,7 +355,7 @@ def qwen_2_model_invariants(
|
||||
},
|
||||
shape_invariants=qwen_2_model_invariants,
|
||||
)
|
||||
class Qwen2Model(nn.Module):
|
||||
class Qwen2Model(nn.Module, EagleModelMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -410,8 +416,6 @@ class Qwen2Model(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
@@ -433,13 +437,14 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -519,7 +524,9 @@ class Qwen2Model(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
class Qwen2ForCausalLM(
|
||||
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -566,13 +573,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -89,6 +89,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
@@ -1000,6 +1001,7 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsMultiModalPruning,
|
||||
SupportsMRoPE,
|
||||
@@ -1143,13 +1145,6 @@ class Qwen2_5_VLForConditionalGeneration(
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.language_model.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.language_model.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Qwen2_5_VLImageInputs | None:
|
||||
|
||||
@@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
|
||||
@@ -258,7 +258,9 @@ class Qwen3Model(Qwen2Model):
|
||||
)
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
class Qwen3ForCausalLM(
|
||||
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -307,13 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
@@ -65,7 +65,14 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .interfaces import (
|
||||
EagleModelMixin,
|
||||
MixtureOfExperts,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -427,7 +434,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Qwen3MoeModel(nn.Module):
|
||||
class Qwen3MoeModel(nn.Module, EagleModelMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -461,8 +468,6 @@ class Qwen3MoeModel(nn.Module):
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
# Track layers for auxiliary hidden state outputs (EAGLE3)
|
||||
self.aux_hidden_state_layers: tuple[int, ...] = ()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -485,18 +490,17 @@ class Qwen3MoeModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state(
|
||||
[], self.start_layer, hidden_states, residual
|
||||
)
|
||||
for layer_idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer),
|
||||
start=self.start_layer,
|
||||
):
|
||||
# Collect auxiliary hidden states if specified
|
||||
if layer_idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_state = (
|
||||
hidden_states + residual if residual is not None else hidden_states
|
||||
)
|
||||
aux_hidden_states.append(aux_hidden_state)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, layer_idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -666,7 +670,7 @@ class Qwen3MoeModel(nn.Module):
|
||||
|
||||
|
||||
class Qwen3MoeForCausalLM(
|
||||
nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
|
||||
nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
@@ -751,13 +755,6 @@ class Qwen3MoeForCausalLM(
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
@@ -101,6 +101,7 @@ from vllm.utils.math_utils import round_up
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
@@ -1275,13 +1276,10 @@ class Qwen3LLMModel(Qwen3Model):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for layer_idx, layer in islice(
|
||||
enumerate(self.layers), self.start_layer, self.end_layer
|
||||
):
|
||||
if layer_idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
@@ -1295,6 +1293,9 @@ class Qwen3LLMModel(Qwen3Model):
|
||||
hidden_states
|
||||
+ deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
|
||||
)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, layer_idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -1351,6 +1352,7 @@ class Qwen3VLForConditionalGeneration(
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
SupportsMRoPE,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsMultiModalPruning,
|
||||
):
|
||||
@@ -1449,13 +1451,6 @@ class Qwen3VLForConditionalGeneration(
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.language_model.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.language_model.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def _get_deepstack_input_embeds(
|
||||
self,
|
||||
num_tokens: int,
|
||||
|
||||
@@ -102,19 +102,17 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state(
|
||||
[], self.start_layer, hidden_states, residual
|
||||
)
|
||||
for layer_idx, layer in islice(
|
||||
enumerate(self.layers), self.start_layer, self.end_layer
|
||||
):
|
||||
if layer_idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
|
||||
if deepstack_input_embeds is not None and layer_idx in range(
|
||||
0, len(deepstack_input_embeds)
|
||||
):
|
||||
@@ -123,6 +121,10 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
+ deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
|
||||
)
|
||||
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, layer_idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
|
||||
@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
EagleModelMixin,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsPP,
|
||||
)
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class StepDecoderModel(nn.Module):
|
||||
class StepDecoderModel(nn.Module, EagleModelMixin):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers: tuple[int, ...] = getattr(
|
||||
config, "aux_hidden_state_layers", ()
|
||||
)
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"],
|
||||
config.hidden_size,
|
||||
@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
if residual is None:
|
||||
aux_hidden_states.append(hidden_states)
|
||||
else:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
self._maybe_add_hidden_state(
|
||||
aux_hidden_states, idx + 1, hidden_states, residual
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step1ForCausalLM(nn.Module, SupportsPP):
|
||||
class Step1ForCausalLM(nn.Module, SupportsPP, SupportsEagle, SupportsEagle3):
|
||||
packed_modules_mapping = STEP_PACKED_MODULES_MAPPING
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
@@ -618,6 +618,6 @@ class Base(
|
||||
# Ensure that the capture hooks are installed before dynamo traces the model
|
||||
maybe_install_capturing_hooks(self.model)
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = self.text_config.num_hidden_layers
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
Reference in New Issue
Block a user