feature: support eagle3 for HunyuanVL & Hunyuan (#33035)

Signed-off-by: irisliu10 <601012173@qq.com>
Signed-off-by: Iris <38269816+irisliu10@users.noreply.github.com>
This commit is contained in:
Iris
2026-01-28 01:55:48 +08:00
committed by GitHub
parent a6760f1525
commit bd92089d33
4 changed files with 49 additions and 3 deletions

View File

@@ -675,7 +675,14 @@ class SpeculativeConfig:
f"{self.disable_by_batch_size=}"
)
eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
eagle3_target_supported = [
"llama",
"qwen",
"minicpm",
"gpt_oss",
"hunyuan_vl",
"hunyuan_v1_dense",
]
if (
self.method == "eagle3"
and self.target_model_config

View File

@@ -66,7 +66,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -630,6 +630,7 @@ class HunYuanModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -654,9 +655,13 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
aux_hidden_states = []
for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual, kv_states = layer(
positions,
hidden_states,
@@ -675,6 +680,9 @@ class HunYuanModel(nn.Module):
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def _split_qkv_weight(self, qkv: torch.Tensor):
@@ -897,7 +905,7 @@ class HunYuanModel(nn.Module):
return loaded_params
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -936,6 +944,13 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward(
self,
input_ids: torch.Tensor | None,

View File

@@ -83,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
@@ -780,6 +781,7 @@ class HunYuanVLForConditionalGeneration(
SupportsPP,
SupportsQuant,
SupportsXDRoPE,
SupportsEagle3,
):
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -966,6 +968,13 @@ class HunYuanVLForConditionalGeneration(
multimodal_embeddings += tuple(image_embeddings)
return multimodal_embeddings
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward(
self,
input_ids: torch.Tensor | None,

View File

@@ -115,6 +115,8 @@ class SpecDecodeBaseProposer:
# Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope
self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
@@ -129,6 +131,12 @@ class SpecDecodeBaseProposer:
self.mrope_positions = torch.zeros(
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device
)
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions = torch.zeros(
(self.uses_xdrope_dim, self.max_num_tokens + 1),
dtype=torch.int64,
device=device,
)
else:
# RoPE need (max_num_tokens,)
self.positions = torch.zeros(
@@ -221,11 +229,15 @@ class SpecDecodeBaseProposer:
def _get_positions(self, num_tokens: int):
if self.uses_mrope:
return self.mrope_positions[:, :num_tokens]
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
return self.xdrope_positions[:, :num_tokens]
return self.positions[:num_tokens]
def _set_positions(self, num_tokens: int, positions: torch.Tensor):
if self.uses_mrope:
self.mrope_positions[:, :num_tokens] = positions
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions[:, :num_tokens] = positions
else:
# Convert M-RoPE positions if target model uses M-RoPE
# but draft doesn't, For text inputs, all M-RoPE
@@ -623,6 +635,8 @@ class SpecDecodeBaseProposer:
self.input_ids[last_token_indices] = next_token_ids
# copy inputs to buffer for cudagraph
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
target_positions = target_positions[0]
self._set_positions(num_tokens, target_positions)
return num_tokens, last_token_indices, cad
@@ -1126,6 +1140,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"GlmOcrForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id