[Hardware][Gaudi][Feature] Enable Dynamic MoE for Mixtral (#12303)

Signed-off-by: zhenwei <zhenweiliu@habana.ai>
liuzhenwei
2025-03-25 00:48:40 +08:00
committed by GitHub
parent 3aee6573dc
commit 5eeadc2642
3 changed files with 57 additions and 2 deletions

@@ -213,6 +213,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            e_score_correction_bias,
        )

    def forward_hpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert layer is not None
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for HPU.")
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for HPU.")
        return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
                                   router_logits, top_k)

    def forward_tpu(
        self,
        layer: torch.nn.Module,
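
For reference, the computation that layer.hpu_fused_moe offloads to the Gaudi kernel can be sketched in eager PyTorch. The sketch below is only an illustration under Mixtral-style assumptions (each expert's w13_weight packs the gate and up projections, SiLU activation, softmax routing with renormalized top-k weights); it is not the code path used on HPU, and moe_reference is a hypothetical helper, not part of this change.

    import torch
    import torch.nn.functional as F

    def moe_reference(x, w13_weight, w2_weight, router_logits, top_k):
        # x: [tokens, hidden]; w13_weight: [experts, 2*inter, hidden];
        # w2_weight: [experts, hidden, inter]  (Mixtral-style packing, assumed).
        probs = F.softmax(router_logits, dim=-1)            # softmax scoring only
        topk_w, topk_idx = probs.topk(top_k, dim=-1)        # pick top_k experts/token
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)  # renormalize weights
        out = torch.zeros_like(x)
        for e in range(w13_weight.shape[0]):
            token_idx, slot = (topk_idx == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            h = x[token_idx] @ w13_weight[e].t()             # fused gate+up projection
            gate, up = h.chunk(2, dim=-1)
            h = (F.silu(gate) * up) @ w2_weight[e].t()       # SwiGLU, then down projection
            out.index_add_(0, token_idx,
                           h * topk_w[token_idx, slot].unsqueeze(-1))
        return out

On Gaudi, forward_hpu hands this whole loop to the single fused op instead of iterating over experts, which is what makes the dynamic MoE path worthwhile there.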
@@ -411,6 +439,9 @@ class FusedMoE(torch.nn.Module):
        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")

        if current_platform.is_hpu():
            from vllm_hpu_extension.ops import DynamicFusedMOE
            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)

        # Note: get_quant_method will look at the layer's local_num_experts
        # for heuristic purposes, so it must be initialized first.
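
Putting the two hunks together: FusedMoE attaches a DynamicFusedMOE instance on HPU, and UnquantizedFusedMoEMethod.forward_hpu later calls it with the layer's packed expert weights. A minimal wiring sketch is below; TinyMoELayer, its weight shapes, and the direct forward call are illustrative assumptions, while the DynamicFusedMOE import and the call signature are taken from the diff itself.

    import torch
    # Only importable where the Gaudi extension is installed.
    from vllm_hpu_extension.ops import DynamicFusedMOE

    class TinyMoELayer(torch.nn.Module):
        """Toy stand-in for FusedMoE showing the attributes forward_hpu expects."""

        def __init__(self, num_experts: int, hidden: int, inter: int):
            super().__init__()
            # Mixtral-style packing assumed: w13 = [gate; up], w2 = down projection.
            self.w13_weight = torch.nn.Parameter(
                torch.empty(num_experts, 2 * inter, hidden))
            self.w2_weight = torch.nn.Parameter(
                torch.empty(num_experts, hidden, inter))
            # Mirrors the FusedMoE.__init__ change above.
            self.hpu_fused_moe = DynamicFusedMOE(num_experts)

        def forward(self, x, router_logits, top_k):
            # Same call shape as UnquantizedFusedMoEMethod.forward_hpu above.
            return self.hpu_fused_moe(x, self.w13_weight, self.w2_weight,
                                      router_logits, top_k)

In vLLM proper the indirection is wider: FusedMoE owns the weights, and the quant method's forward_hpu is the piece that issues this call; the sketch simply collapses that into one module.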