[torch.compile] support moe models (#9632)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-10-27 21:58:04 -07:00
committed by GitHub
parent 4e2d95e372
commit 32176fee73
12 changed files with 216 additions and 77 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
@@ -435,10 +436,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -449,7 +446,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_marlin_moe(
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,