[Refactor] Consolidate SupportsEagle (#36063)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-03-13 19:22:40 -04:00
committed by GitHub
parent 54a6db827f
commit 8b346309a5
24 changed files with 229 additions and 235 deletions

View File

@@ -65,7 +65,14 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -427,7 +434,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
@support_torch_compile
class Qwen3MoeModel(nn.Module):
class Qwen3MoeModel(nn.Module, EagleModelMixin):
def __init__(
self,
*,
@@ -461,8 +468,6 @@ class Qwen3MoeModel(nn.Module):
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
# Track layers for auxiliary hidden state outputs (EAGLE3)
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -485,18 +490,17 @@ class Qwen3MoeModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state(
[], self.start_layer, hidden_states, residual
)
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
# Collect auxiliary hidden states if specified
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_state = (
hidden_states + residual if residual is not None else hidden_states
)
aux_hidden_states.append(aux_hidden_state)
hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@@ -666,7 +670,7 @@ class Qwen3MoeModel(nn.Module):
class Qwen3MoeForCausalLM(
nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
):
packed_modules_mapping = {
"qkv_proj": [
@@ -751,13 +755,6 @@ class Qwen3MoeForCausalLM(
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)