feature: support eagle3 for HunyuanVL & Hunyuan (#33035)

Signed-off-by: irisliu10 <601012173@qq.com>
Signed-off-by: Iris <38269816+irisliu10@users.noreply.github.com>
This commit is contained in:
Iris
2026-01-28 01:55:48 +08:00
committed by GitHub
parent a6760f1525
commit bd92089d33
4 changed files with 49 additions and 3 deletions

View File

@@ -66,7 +66,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -630,6 +630,7 @@ class HunYuanModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -654,9 +655,13 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
aux_hidden_states = []
for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual, kv_states = layer(
positions,
hidden_states,
@@ -675,6 +680,9 @@ class HunYuanModel(nn.Module):
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def _split_qkv_weight(self, qkv: torch.Tensor):
@@ -897,7 +905,7 @@ class HunYuanModel(nn.Module):
return loaded_params
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -936,6 +944,13 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward(
self,
input_ids: torch.Tensor | None,

View File

@@ -83,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
@@ -780,6 +781,7 @@ class HunYuanVLForConditionalGeneration(
SupportsPP,
SupportsQuant,
SupportsXDRoPE,
SupportsEagle3,
):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -966,6 +968,13 @@ class HunYuanVLForConditionalGeneration(
multimodal_embeddings += tuple(image_embeddings)
return multimodal_embeddings
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward(
self,
input_ids: torch.Tensor | None,