Fix MoE for the Transformers modelling backend (#34436)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2026-02-12 18:29:42 +01:00
Committed by: GitHub
Parent: f2c47886fd
Commit: 679ca5d8d3


@@ -45,7 +45,6 @@ class TransformersFusedMoE(FusedMoE):
     # --8<-- [end:transformers_fused_moe]

     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
         self._topk_ids: torch.Tensor = None

         def custom_routing_function(hidden_states, gating_output, topk, renormalize):
@@ -63,7 +62,8 @@ class TransformersFusedMoE(FusedMoE):
                 (topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes)
             return topk_weights, topk_ids

-        self.custom_routing_function = custom_routing_function
+        kwargs["custom_routing_function"] = custom_routing_function
+        super().__init__(*args, **kwargs)

     def forward(
         self,
@@ -94,7 +94,7 @@ def transformers_moe_forward(
     self = forward_context.no_compile_layers[layer_name]
     self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), topk_weights)
+    return self.runner.forward(hidden_states.clone(), topk_weights)


 def transformers_moe_forward_fake(
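For context, below is a minimal, self-contained sketch of the two patterns this diff moves to: the custom routing function is placed into `kwargs` before `super().__init__()` runs (presumably because `FusedMoE` consumes it during construction, so assigning it afterwards was too late), and the wrapping forward function caches the Transformers-chosen `topk_ids` on the layer before delegating via the layer's runner. `FusedMoEStub`, `TransformersFusedMoESketch`, `moe_forward_sketch` and the `layers` dict are hypothetical stand-ins, not vLLM's real classes; only the identifiers taken from the hunks above (`custom_routing_function`, `_topk_ids`, `no_compile_layers`, `runner.forward`) come from the commit.

```python
from typing import Callable, Optional

import torch


class FusedMoEStub:
    """Stand-in for FusedMoE. Assumes the routing callback is consumed
    during construction, which is why it must arrive via kwargs."""

    def __init__(self, *, custom_routing_function: Optional[Callable] = None, **_):
        self.custom_routing_function = custom_routing_function
        # Hypothetical runner; in this sketch the layer is its own runner.
        self.runner = self

    def forward(self, hidden_states: torch.Tensor,
                topk_weights: torch.Tensor) -> torch.Tensor:
        return hidden_states  # placeholder for the fused expert computation


class TransformersFusedMoESketch(FusedMoEStub):
    def __init__(self, *args, **kwargs):
        # Cache for the expert ids chosen by the Transformers model.
        self._topk_ids: Optional[torch.Tensor] = None

        def custom_routing_function(hidden_states, gating_output, topk, renormalize):
            # Routing was already decided upstream; hand back the cached ids
            # together with the weights (simplified relative to the diff).
            return gating_output, self._topk_ids

        # Key change: put the callback into kwargs *before* calling the
        # parent constructor, so the parent sees it while it is being built.
        kwargs["custom_routing_function"] = custom_routing_function
        super().__init__(*args, **kwargs)


def moe_forward_sketch(layers: dict, layer_name: str,
                       hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor,
                       topk_ids: torch.Tensor) -> torch.Tensor:
    """Mirror of the last hunk: look the layer up by name, stash the routing
    ids for custom_routing_function, then delegate via the layer's runner."""
    layer = layers[layer_name]  # stands in for forward_context.no_compile_layers
    layer._topk_ids = topk_ids
    # Clone because the fused kernel may mutate hidden_states in place.
    return layer.runner.forward(hidden_states.clone(), topk_weights)
```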