Add: Eagle3 support for Qwen3.5 (#36658)
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@@ -75,6 +75,7 @@ from .interfaces import (
     IsHybrid,
     MixtureOfExperts,
     MultiModalEmbeddings,
+    SupportsEagle3,
     SupportsLoRA,
     SupportsPP,
     _require_is_multimodal,
@@ -353,6 +354,8 @@ class Qwen3_5Model(Qwen3NextModel):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers: tuple[int, ...] = ()
+
     def load_fused_expert_weights(
         self,
         name: str,
@@ -536,6 +539,7 @@ class Qwen3_5Model(Qwen3NextModel):
 class Qwen3_5ForCausalLMBase(
     nn.Module,
     HasInnerState,
+    SupportsEagle3,
     SupportsLoRA,
     SupportsPP,
 ):
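
Note: the SupportsEagle3 mixin added to the base-class list corresponds to the two hooks this commit implements below (set_aux_hidden_state_layers and get_eagle3_aux_hidden_state_layers). A minimal sketch of the implied contract, inferred from this diff rather than taken from the actual interfaces module:

from typing import Protocol, runtime_checkable

@runtime_checkable
class _SupportsEagle3Sketch(Protocol):
    """Hypothetical stand-in for the real SupportsEagle3 interface."""

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: ...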
@@ -592,6 +596,13 @@ class Qwen3_5ForCausalLMBase(
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
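
Note: the default get_eagle3_aux_hidden_state_layers() heuristic picks one early, one middle, and one late layer. A worked example, assuming a hypothetical 48-layer model:

# Worked example of the default aux-layer heuristic above.
# The layer count here (48) is an assumption for illustration.
num_layers = 48
aux_layers = (2, num_layers // 2, num_layers - 3)
assert aux_layers == (2, 24, 45)  # early, middle, late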
@@ -1148,6 +1148,8 @@ class Qwen3NextModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers: tuple[int, ...] = ()
+
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -1157,7 +1159,7 @@ class Qwen3NextModel(nn.Module):
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor:
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -1169,7 +1171,15 @@ class Qwen3NextModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer),
+            start=self.start_layer,
+        ):
+            if layer_idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
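
Note: these decoder layers keep the residual stream separate between layers, so the full activation at a layer boundary is hidden_states + residual; the capture above re-materializes it only for the selected layers. A self-contained sketch of the same pattern with stand-in tensors (all names and shapes here are illustrative):

import torch

aux_hidden_state_layers = (0, 2)  # assumed selection for this sketch
hidden_states = torch.randn(1, 4, 8)
residual: torch.Tensor | None = None

aux_hidden_states: list[torch.Tensor] = []
for layer_idx in range(3):
    if layer_idx in aux_hidden_state_layers:
        # Full activation = partial hidden_states plus the pending residual.
        aux_hidden_states.append(
            hidden_states + residual if residual is not None else hidden_states
        )
    # Stand-in for `layer(...)`, which returns (hidden_states, residual).
    hidden_states, residual = hidden_states * 0.5, hidden_states

assert len(aux_hidden_states) == 2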
@@ -1181,6 +1191,8 @@ class Qwen3NextModel(nn.Module):
                 {"hidden_states": hidden_states, "residual": residual}
             )
         hidden_states, _ = self.norm(hidden_states, residual)
+        if aux_hidden_states:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
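
Note: once aux layers are configured, forward() returns a (hidden_states, aux_hidden_states) tuple instead of a bare tensor, matching the widened return annotation above. A hedged sketch of how a caller might normalize the two shapes (the helper name is invented for illustration):

import torch

def unpack_forward_output(
    output: torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]],
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Hypothetical helper: normalize the widened forward() return type."""
    if isinstance(output, tuple):
        hidden_states, aux_hidden_states = output
        return hidden_states, aux_hidden_states
    return output, []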