Fix MoE for the Transformers modelling backend (#34436)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -45,7 +45,6 @@ class TransformersFusedMoE(FusedMoE):
     # --8<-- [end:transformers_fused_moe]
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
         self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk, renormalize):
@@ -63,7 +62,8 @@ class TransformersFusedMoE(FusedMoE):
                 (topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes)
             return topk_weights, topk_ids
 
-        self.custom_routing_function = custom_routing_function
+        kwargs["custom_routing_function"] = custom_routing_function
+        super().__init__(*args, **kwargs)
 
     def forward(
         self,
@@ -94,7 +94,7 @@ def transformers_moe_forward(
     self = forward_context.no_compile_layers[layer_name]
     self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), topk_weights)
+    return self.runner.forward(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(
Reference in New Issue
Block a user