[Bugfix] Fix some issues with MoERunner PR #32344 (#34371)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2026-02-11 17:33:14 -05:00
committed by GitHub
parent 5aff2699bd
commit 31d992d215
2 changed files with 6 additions and 3 deletions

View File

@@ -700,7 +700,7 @@ class FusedMoE(CustomOp):
@property
def gate(self) -> torch.nn.Module | None:
- return self._gate
+ return self._gate if self.use_overlapped else None
@property
def tp_size(self):
@@ -725,7 +725,7 @@ class FusedMoE(CustomOp):
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
- return self._gate is not None
+ return self.gate is not None
def _maybe_init_expert_routing_tables(
self,
@@ -1457,7 +1457,6 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- self.ensure_moe_quant_config_init()
return self.runner.forward(
hidden_states,
router_logits,

View File

@@ -63,6 +63,8 @@ def _moe_forward(
layer_name: str,
) -> torch.Tensor:
layer = get_layer_from_name(layer_name)
+ # TODO(bnell): this can be removed after MK migration is complete.
+ layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
layer, hidden_states, router_logits, shared_experts_input
)
@@ -84,6 +86,8 @@ def _moe_forward_shared(
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(layer_name)
+ # TODO(bnell): this can be removed after MK migration is complete.
+ layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
layer, hidden_states, router_logits, shared_experts_input
)