From f3888aca83edf5125c37d0983f1df94df27d9920 Mon Sep 17 00:00:00 2001
From: AutumnAurelium <88015631+AutumnAurelium@users.noreply.github.com>
Date: Fri, 30 Jan 2026 22:53:08 -0800
Subject: [PATCH] Add EAGLE3 support for AFMoE (#33111)

Signed-off-by: AutumnAurelium <88015631+AutumnAurelium@users.noreply.github.com>
---
 vllm/config/speculative.py          |  1 +
 vllm/model_executor/models/afmoe.py | 34 +++++++++++++++++++++++++++++-----
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 29dbab9f7..f3de1e171 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -682,6 +682,7 @@ class SpeculativeConfig:
             "gpt_oss",
             "hunyuan_vl",
             "hunyuan_v1_dense",
+            "afmoe",
         ]
         if (
             self.method == "eagle3"
diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py
index a47fe4b7b..6f5d7d766 100644
--- a/vllm/model_executor/models/afmoe.py
+++ b/vllm/model_executor/models/afmoe.py
@@ -36,7 +36,11 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.interfaces import (
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
@@ -416,6 +420,8 @@ class AfmoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers = tuple[int, ...]()
+
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -429,7 +435,7 @@ class AfmoeModel(nn.Module):
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor | IntermediateTensors:
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -446,7 +452,14 @@ class AfmoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer)
+        ):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
@@ -455,6 +468,10 @@ class AfmoeModel(nn.Module):
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def make_empty_intermediate_tensors(
@@ -586,7 +603,7 @@ class AfmoeModel(nn.Module):
         return loaded_params
 
 
-class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -673,13 +690,20 @@ class AfmoeForCausalLM(nn.Module):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor | None,
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor | IntermediateTensors:
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
         hidden_states = self.model(
             input_ids, positions, intermediate_tensors, inputs_embeds
         )
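
Usage note (not part of the patch): with "afmoe" on the eagle3 target
allowlist and AfmoeForCausalLM implementing SupportsEagle3, the EAGLE3
draft head can consume auxiliary hidden states captured at the input of
an early, a middle, and a late decoder layer (by default layers 2,
num_layers // 2, and num_layers - 3; for a 32-layer model that is
layers 2, 16, and 29). Below is a minimal offline-inference sketch of
how this path would be exercised through vLLM's speculative_config;
both checkpoint names are hypothetical placeholders, not published
models.

    from vllm import LLM, SamplingParams

    # Target AFMoE model plus an EAGLE3 draft head trained against it.
    # Both model names below are hypothetical placeholders.
    llm = LLM(
        model="example-org/afmoe-chat",
        speculative_config={
            "method": "eagle3",
            "model": "example-org/afmoe-eagle3-draft",
            "num_speculative_tokens": 3,
        },
    )

    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["Speculative decoding works by"], params)
    print(outputs[0].outputs[0].text)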