Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -700,7 +700,7 @@ class FusedMoE(CustomOp):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def gate(self) -> torch.nn.Module | None:
|
def gate(self) -> torch.nn.Module | None:
|
||||||
return self._gate
|
return self._gate if self.use_overlapped else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self):
|
def tp_size(self):
|
||||||
@@ -725,7 +725,7 @@ class FusedMoE(CustomOp):
|
|||||||
@property
|
@property
|
||||||
def is_internal_router(self) -> bool:
|
def is_internal_router(self) -> bool:
|
||||||
# By default, router/gate is called before FusedMoE forward pass
|
# By default, router/gate is called before FusedMoE forward pass
|
||||||
return self._gate is not None
|
return self.gate is not None
|
||||||
|
|
||||||
def _maybe_init_expert_routing_tables(
|
def _maybe_init_expert_routing_tables(
|
||||||
self,
|
self,
|
||||||
@@ -1457,7 +1457,6 @@ class FusedMoE(CustomOp):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
self.ensure_moe_quant_config_init()
|
|
||||||
return self.runner.forward(
|
return self.runner.forward(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ def _moe_forward(
|
|||||||
layer_name: str,
|
layer_name: str,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
layer = get_layer_from_name(layer_name)
|
layer = get_layer_from_name(layer_name)
|
||||||
|
# TODO(bnell): this can be removed after MK migration is complete.
|
||||||
|
layer.ensure_moe_quant_config_init()
|
||||||
return layer.runner.forward_impl(
|
return layer.runner.forward_impl(
|
||||||
layer, hidden_states, router_logits, shared_experts_input
|
layer, hidden_states, router_logits, shared_experts_input
|
||||||
)
|
)
|
||||||
@@ -84,6 +86,8 @@ def _moe_forward_shared(
|
|||||||
layer_name: str,
|
layer_name: str,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
layer = get_layer_from_name(layer_name)
|
layer = get_layer_from_name(layer_name)
|
||||||
|
# TODO(bnell): this can be removed after MK migration is complete.
|
||||||
|
layer.ensure_moe_quant_config_init()
|
||||||
return layer.runner.forward_impl(
|
return layer.runner.forward_impl(
|
||||||
layer, hidden_states, router_logits, shared_experts_input
|
layer, hidden_states, router_logits, shared_experts_input
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user