feature: support EAGLE3 speculative decoding for HunyuanVL & Hunyuan (#33035)
Signed-off-by: irisliu10 <601012173@qq.com>
Signed-off-by: Iris <38269816+irisliu10@users.noreply.github.com>
This commit is contained in:
@@ -66,7 +66,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -630,6 +630,7 @@ class HunYuanModel(nn.Module):
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Map token ids to their embedding vectors.

    Thin wrapper over the model's token embedding table
    (``self.embed_tokens``); no positional or extra processing happens here.
    """
    embeddings = self.embed_tokens(input_ids)
    return embeddings
|
||||
@@ -654,9 +655,13 @@ class HunYuanModel(nn.Module):
|
||||
|
||||
cla_factor = _get_cla_factor(self.config)
|
||||
prev_kv_states = None
|
||||
aux_hidden_states = []
|
||||
for i, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if i in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
|
||||
hidden_states, residual, kv_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
@@ -675,6 +680,9 @@ class HunYuanModel(nn.Module):
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if len(aux_hidden_states) > 0:
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def _split_qkv_weight(self, qkv: torch.Tensor):
|
||||
@@ -897,7 +905,7 @@ class HunYuanModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -936,6 +944,13 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Record which decoder layers must emit auxiliary hidden states.

    The indices are stored on the inner model, whose forward pass collects
    the corresponding intermediate activations (presumably consumed by the
    EAGLE3 draft head — see the sibling ``get_eagle3_...`` default).
    """
    self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default layer indices tapped for EAGLE3 auxiliary states.

    Chooses one early, one middle, and one late decoder layer
    (2, n // 2, n - 3) out of the model's n layers.
    """
    total = len(self.model.layers)
    return (2, total // 2, total - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
@@ -83,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
@@ -780,6 +781,7 @@ class HunYuanVLForConditionalGeneration(
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
SupportsXDRoPE,
|
||||
SupportsEagle3,
|
||||
):
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@@ -966,6 +968,13 @@ class HunYuanVLForConditionalGeneration(
|
||||
multimodal_embeddings += tuple(image_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Record which language-model decoder layers must emit aux hidden states.

    For this multimodal wrapper the indices live on the wrapped language
    model's inner module, which gathers those intermediate activations
    during its forward pass (presumably for the EAGLE3 draft head).
    """
    self.language_model.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default layer indices tapped for EAGLE3 auxiliary states.

    Mirrors the text-only model's choice — an early, a middle, and a late
    layer (2, n // 2, n - 3) — but counted over the wrapped language model.
    """
    total = len(self.language_model.model.layers)
    return (2, total // 2, total - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
|
||||
Reference in New Issue
Block a user