Add EAGLE3 support for AFMoE (#33111)

Signed-off-by: AutumnAurelium <88015631+AutumnAurelium@users.noreply.github.com>
AutumnAurelium authored 2026-01-30 22:53:08 -08:00, committed by GitHub
parent f0bca83ee4
commit f3888aca83
2 changed files with 30 additions and 5 deletions


@@ -682,6 +682,7 @@ class SpeculativeConfig:
"gpt_oss", "gpt_oss",
"hunyuan_vl", "hunyuan_vl",
"hunyuan_v1_dense", "hunyuan_v1_dense",
"afmoe",
] ]
if ( if (
self.method == "eagle3" self.method == "eagle3"
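With "afmoe" on the EAGLE3 allowlist, speculative decoding for an AFMoE target can be requested the same way as for the other listed architectures. A minimal sketch using vLLM's offline API; the checkpoint names and num_speculative_tokens value are placeholders, not something this commit ships:

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="org/afmoe-instruct",                # hypothetical AFMoE target model
        speculative_config={
            "method": "eagle3",                    # accepted for afmoe after this change
            "model": "org/afmoe-eagle3-draft",     # hypothetical EAGLE3 draft head
            "num_speculative_tokens": 3,
        },
    )
    print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))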


@@ -36,7 +36,11 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.interfaces import (
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
@@ -416,6 +420,8 @@ class AfmoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers = tuple[int, ...]()
+
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
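The default added above is an empty tuple: calling a subscripted generic such as tuple[int, ...]() just invokes the underlying tuple(), so no auxiliary hidden states are collected until set_aux_hidden_state_layers() is called by the EAGLE3 setup path. A one-line sanity check (illustrative, not part of the diff):

    assert tuple[int, ...]() == ()  # subscripted generics call their origin type when instantiated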
@@ -429,7 +435,7 @@ class AfmoeModel(nn.Module):
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor | IntermediateTensors:
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -446,7 +452,14 @@ class AfmoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer)
+        ):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
@@ -455,6 +468,10 @@ class AfmoeModel(nn.Module):
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def make_empty_intermediate_tensors(
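Because forward() can now hand back either a single tensor, IntermediateTensors, or a (hidden_states, aux_hidden_states) pair, downstream code has to branch on the return type. A hedged, self-contained sketch of that unpacking; the helper name and dummy tensors are illustrative and not vLLM's actual runner code:

    import torch

    def split_model_output(out):
        # Normalize the widened return type to (hidden_states, aux_hidden_states).
        if isinstance(out, tuple):
            return out          # EAGLE3 aux layers were configured
        return out, []          # plain single-tensor path

    hidden, aux = split_model_output(torch.zeros(4, 8))                          # no aux layers
    hidden, aux = split_model_output((torch.zeros(4, 8), [torch.zeros(4, 8)]))   # aux layers set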
@@ -586,7 +603,7 @@ class AfmoeModel(nn.Module):
         return loaded_params
 
 
-class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -673,13 +690,20 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor | None,
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor | IntermediateTensors:
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
         hidden_states = self.model(
             input_ids, positions, intermediate_tensors, inputs_embeds
         )
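For reference, the default returned by get_eagle3_aux_hidden_state_layers() picks one early, one middle, and one late decoder layer, the same early/middle/late pattern other EAGLE3-enabled targets in vLLM use. Worked out for a hypothetical 32-layer AFMoE model (the layer count is illustrative only):

    num_layers = 32                                              # illustrative value
    assert (2, num_layers // 2, num_layers - 3) == (2, 16, 29)   # early, middle, late layers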