[torch.compile] support moe models (#9632)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
@@ -435,10 +436,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
fused_marlin_moe)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@@ -449,7 +446,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function)
|
||||
|
||||
return fused_marlin_moe(
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
|
||||
Reference in New Issue
Block a user