[Model] Support Grok1 (#13795)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-02-25 20:07:12 -05:00
Committed by: GitHub
parent 34e3494e70
commit 07c4353057
11 changed files with 634 additions and 17 deletions
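
This change threads a new `activation` keyword (default `"silu"`, the previously hard-coded behavior) from the `FusedMoE` module down to the fused expert kernels, so that models whose experts gate with a different nonlinearity, such as Grok-1's GELU, can reuse the same layer. As a minimal sketch of what the flag selects, here is an unfused reference of the expert computation; the helper is illustrative only, with shapes following the `w13_weight`/`w2_weight` convention visible in the diff:

```python
import torch
import torch.nn.functional as F


def reference_moe(x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor,
                  router_logits: torch.Tensor, top_k: int,
                  activation: str = "silu") -> torch.Tensor:
    # x: [T, H] tokens; w13: [E, 2*I, H] stacked gate/up projections;
    # w2: [E, H, I] down projections; router_logits: [T, E].
    if activation not in ("silu", "gelu"):
        raise ValueError(f"Unsupported activation: {activation}")
    act = F.silu if activation == "silu" else F.gelu
    topk_weights, topk_ids = router_logits.softmax(dim=-1).topk(top_k, dim=-1)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for weight, e in zip(topk_weights[t], topk_ids[t]):
            gate, up = (x[t] @ w13[e].t()).chunk(2, dim=-1)
            out[t] += weight * ((act(gate) * up) @ w2[e].t())
    return out
```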


@@ -120,7 +120,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(x=x,
                             layer=layer,
@@ -134,7 +135,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                             expert_map=expert_map,
                             custom_routing_function=custom_routing_function,
                             scoring_func=scoring_func,
-                            e_score_correction_bias=e_score_correction_bias)
+                            e_score_correction_bias=e_score_correction_bias,
+                            activation=activation)
 
     def forward_cuda(
         self,
@@ -150,7 +152,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -170,6 +173,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                              topk_weights=topk_weights,
                              topk_ids=topk_ids,
                              inplace=True,
+                             activation=activation,
                              global_num_experts=global_num_experts,
                              expert_map=expert_map)
 
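
On the CUDA path, `fused_experts` now receives the string and applies the matching fused activation-and-multiply step between the two expert matmuls. A minimal stand-in for that dispatch, assuming only that the kernel wrapper branches on the string (the real fused ops are custom CUDA kernels not shown in this diff):

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the fused step between the w13 and w2 matmuls:
# the first half of the last dim is the gate, the second half the up proj.
_ACT_AND_MUL = {
    "silu": lambda gate, up: F.silu(gate) * up,
    "gelu": lambda gate, up: F.gelu(gate) * up,
}


def act_and_mul(gate_up: torch.Tensor, activation: str) -> torch.Tensor:
    if activation not in _ACT_AND_MUL:
        raise ValueError(f"Unsupported activation: {activation}")
    gate, up = gate_up.chunk(2, dim=-1)
    return _ACT_AND_MUL[activation](gate, up)
```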
@@ -186,9 +190,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
         **kwargs,
     ):
         assert custom_routing_function is None
+        assert activation == "silu", f"{activation} is not supported."
         return layer.ipex_fusion(
             x,
             use_grouped_topk,
@@ -213,7 +219,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
@@ -225,6 +232,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if e_score_correction_bias is not None:
             raise NotImplementedError(
                 "Expert score correction bias is not supported for TPU.")
+        assert activation == "silu", f"{activation} is not supported for TPU."
         return fused_moe_pallas(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
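
Note that only the CUDA path consumes the new argument; the IPEX (CPU) and Pallas (TPU) paths assert that it is still `"silu"`. A caller wanting to fail fast before dispatch could mirror those guards. The per-platform sets below are read off the asserts in this diff; the CUDA entry is an assumption based on the Grok-1 use case, not something this hunk shows:

```python
# Which MoE activations each backend accepts (assumed, see note above).
_SUPPORTED_MOE_ACTIVATIONS = {
    "cuda": {"silu", "gelu"},
    "cpu": {"silu"},
    "tpu": {"silu"},
}


def check_moe_activation(platform: str, activation: str) -> None:
    if activation not in _SUPPORTED_MOE_ACTIVATIONS.get(platform, set()):
        raise NotImplementedError(
            f"activation={activation!r} is not supported on {platform}.")
```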
@@ -277,6 +285,7 @@ class FusedMoE(torch.nn.Module):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -305,6 +314,7 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.activation = activation
 
         self.expert_map = None
         if self.ep_size > 1:
@@ -653,7 +663,9 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
-            e_score_correction_bias=self.e_score_correction_bias)
+            e_score_correction_bias=self.e_score_correction_bias,
+            activation=self.activation,
+        )
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
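
End to end, a model opts in at construction time. Below is a sketch of Grok-1-style usage; constructor arguments besides `activation` are placeholders, while the real Grok-1 module added by this commit wires them from the model config and quantization setup:

```python
import torch
from torch import nn

from vllm.model_executor.layers.fused_moe import FusedMoE


class Grok1StyleMoE(nn.Module):
    """Illustrative wrapper, not the Grok1MoE shipped in this commit."""

    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int):
        super().__init__()
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            renormalize=True,
            activation="gelu",  # new keyword; Grok-1 experts use GELU
        )

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        return self.experts(hidden_states, router_logits)
```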