[Bugfix] Define router_logits_dtype for remaining MoE models (#33737)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -142,6 +142,7 @@ class AfmoeMoE(nn.Module):
|
|||||||
e_score_correction_bias=self.expert_bias,
|
e_score_correction_bias=self.expert_bias,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts,
|
||||||
|
router_logits_dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -300,6 +300,7 @@ class BailingMoE(nn.Module):
|
|||||||
num_expert_group=self.n_group,
|
num_expert_group=self.n_group,
|
||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
|
router_logits_dtype=self.router_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -71,7 +71,6 @@ class FlexOlmoMoE(nn.Module):
|
|||||||
prefix=f"{prefix}.gate",
|
prefix=f"{prefix}.gate",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Gate always runs at half / full precision for now.
|
|
||||||
self.experts = FusedMoE(
|
self.experts = FusedMoE(
|
||||||
num_experts=hf_config.num_experts,
|
num_experts=hf_config.num_experts,
|
||||||
top_k=hf_config.num_experts_per_tok,
|
top_k=hf_config.num_experts_per_tok,
|
||||||
@@ -82,6 +81,7 @@ class FlexOlmoMoE(nn.Module):
|
|||||||
quant_config=None,
|
quant_config=None,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
prefix=f"{prefix}.experts",
|
prefix=f"{prefix}.experts",
|
||||||
|
router_logits_dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.top_k = hf_config.num_experts_per_tok
|
self.top_k = hf_config.num_experts_per_tok
|
||||||
|
|||||||
@@ -236,9 +236,9 @@ class FlashMLP(nn.Module):
|
|||||||
class LongcatRouter(nn.Module):
|
class LongcatRouter(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config: FlashConfig,
|
||||||
zero_expert_num=0,
|
zero_expert_num: int,
|
||||||
rounter_params_dtype=torch.bfloat16,
|
rounter_params_dtype: torch.dtype,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -309,6 +309,7 @@ class LongcatMoe(nn.Module):
|
|||||||
prefix=f"{prefix}.experts",
|
prefix=f"{prefix}.experts",
|
||||||
enable_eplb=enable_eplb,
|
enable_eplb=enable_eplb,
|
||||||
routed_scaling_factor=config.routed_scaling_factor,
|
routed_scaling_factor=config.routed_scaling_factor,
|
||||||
|
router_logits_dtype=self.rounter_params_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ class MiMoV2MoE(nn.Module):
|
|||||||
num_expert_group=config.n_group,
|
num_expert_group=config.n_group,
|
||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
scoring_func="sigmoid",
|
scoring_func="sigmoid",
|
||||||
|
router_logits_dtype=self.gate_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -388,6 +388,7 @@ class FusedMoEBlock(nn.Module):
|
|||||||
routed_scaling_factor=config.moe_router_scaling_factor,
|
routed_scaling_factor=config.moe_router_scaling_factor,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts,
|
||||||
|
router_logits_dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user