Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -700,7 +700,7 @@ class FusedMoE(CustomOp):
|
||||
|
||||
@property
|
||||
def gate(self) -> torch.nn.Module | None:
|
||||
return self._gate
|
||||
return self._gate if self.use_overlapped else None
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
@@ -725,7 +725,7 @@ class FusedMoE(CustomOp):
|
||||
@property
|
||||
def is_internal_router(self) -> bool:
|
||||
# By default, router/gate is called before FusedMoE forward pass
|
||||
return self._gate is not None
|
||||
return self.gate is not None
|
||||
|
||||
def _maybe_init_expert_routing_tables(
|
||||
self,
|
||||
@@ -1457,7 +1457,6 @@ class FusedMoE(CustomOp):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
self.ensure_moe_quant_config_init()
|
||||
return self.runner.forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
|
||||
@@ -63,6 +63,8 @@ def _moe_forward(
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
layer, hidden_states, router_logits, shared_experts_input
|
||||
)
|
||||
@@ -84,6 +86,8 @@ def _moe_forward_shared(
|
||||
layer_name: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
layer, hidden_states, router_logits, shared_experts_input
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user