Add: Eagle3 support for Qwen3.5 (#36658)

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
This commit is contained in:
Rahul Tuli
2026-03-11 15:37:42 +05:30
committed by GitHub
parent 646b85544b
commit 9d07a3d6e4
2 changed files with 25 additions and 2 deletions

View File

@@ -75,6 +75,7 @@ from .interfaces import (
IsHybrid,
MixtureOfExperts,
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
_require_is_multimodal,
@@ -353,6 +354,8 @@ class Qwen3_5Model(Qwen3NextModel):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def load_fused_expert_weights(
self,
name: str,
@@ -536,6 +539,7 @@ class Qwen3_5Model(Qwen3NextModel):
class Qwen3_5ForCausalLMBase(
nn.Module,
HasInnerState,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
):
@@ -592,6 +596,13 @@ class Qwen3_5ForCausalLMBase(
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -1148,6 +1148,8 @@ class Qwen3NextModel(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -1157,7 +1159,7 @@ class Qwen3NextModel(nn.Module):
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@@ -1169,7 +1171,15 @@ class Qwen3NextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
aux_hidden_states = []
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
@@ -1181,6 +1191,8 @@ class Qwen3NextModel(nn.Module):
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if aux_hidden_states:
return hidden_states, aux_hidden_states
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: