[Misc][Refactor] Add FusedMoERouter object (#30519)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
@@ -997,6 +998,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
router: FusedMoERouter,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -1051,7 +1053,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
topk_weights, topk_ids = router.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user