Add EAGLE3 support for AFMoE (#33111)
Signed-off-by: AutumnAurelium <88015631+AutumnAurelium@users.noreply.github.com>
This commit is contained in:
@@ -682,6 +682,7 @@ class SpeculativeConfig:
|
|||||||
"gpt_oss",
|
"gpt_oss",
|
||||||
"hunyuan_vl",
|
"hunyuan_vl",
|
||||||
"hunyuan_v1_dense",
|
"hunyuan_v1_dense",
|
||||||
|
"afmoe",
|
||||||
]
|
]
|
||||||
if (
|
if (
|
||||||
self.method == "eagle3"
|
self.method == "eagle3"
|
||||||
|
|||||||
@@ -36,7 +36,11 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
maybe_remap_kv_scale_name,
|
maybe_remap_kv_scale_name,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
from vllm.model_executor.models.interfaces import (
|
||||||
|
SupportsEagle3,
|
||||||
|
SupportsLoRA,
|
||||||
|
SupportsPP,
|
||||||
|
)
|
||||||
from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
|
from vllm.model_executor.models.llama import LlamaMLP as AfmoeMLP
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
@@ -416,6 +420,8 @@ class AfmoeModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.norm = PPMissingLayer()
|
self.norm = PPMissingLayer()
|
||||||
|
|
||||||
|
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||||
["hidden_states", "residual"], config.hidden_size
|
["hidden_states", "residual"], config.hidden_size
|
||||||
)
|
)
|
||||||
@@ -429,7 +435,7 @@ class AfmoeModel(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
inputs_embeds: torch.Tensor | None = None,
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | IntermediateTensors:
|
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@@ -446,7 +452,14 @@ class AfmoeModel(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
aux_hidden_states = []
|
||||||
|
for idx, layer in enumerate(
|
||||||
|
islice(self.layers, self.start_layer, self.end_layer)
|
||||||
|
):
|
||||||
|
if idx in self.aux_hidden_state_layers:
|
||||||
|
aux_hidden_states.append(
|
||||||
|
hidden_states + residual if residual is not None else hidden_states
|
||||||
|
)
|
||||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
@@ -455,6 +468,10 @@ class AfmoeModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
if len(aux_hidden_states) > 0:
|
||||||
|
return hidden_states, aux_hidden_states
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def make_empty_intermediate_tensors(
|
def make_empty_intermediate_tensors(
|
||||||
@@ -586,7 +603,7 @@ class AfmoeModel(nn.Module):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
@@ -673,13 +690,20 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
|||||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.model.embed_input_ids(input_ids)
|
return self.model.embed_input_ids(input_ids)
|
||||||
|
|
||||||
|
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||||
|
self.model.aux_hidden_state_layers = layers
|
||||||
|
|
||||||
|
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||||
|
num_layers = len(self.model.layers)
|
||||||
|
return (2, num_layers // 2, num_layers - 3)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor | None,
|
input_ids: torch.Tensor | None,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
intermediate_tensors: IntermediateTensors | None = None,
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
inputs_embeds: torch.Tensor | None = None,
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | IntermediateTensors:
|
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user