[torch.compile] support moe models (#9632)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-10-27 21:58:04 -07:00
Committed by: GitHub
Parent: 4e2d95e372
Commit: 32176fee73
12 changed files with 216 additions and 77 deletions
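The substance of the change is visible in the hunks below: the Marlin MoE kernels are no longer imported and called as plain Python functions; they are invoked through the torch.ops.vllm namespace, i.e. as PyTorch custom ops that torch.compile treats as opaque calls instead of attempting to trace into their Triton/CUDA internals. A minimal sketch of that registration pattern, assuming PyTorch >= 2.4 and using a hypothetical toy_moe_gate op rather than the commit's actual code:

import torch


@torch.library.custom_op("vllm::toy_moe_gate", mutates_args=())
def toy_moe_gate(hidden: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor:
    # A real kernel (e.g. fused_marlin_moe) would run Triton/CUDA code here;
    # torch.compile never traces into the body of a registered custom op.
    return hidden @ gate_weight


@toy_moe_gate.register_fake
def _(hidden: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only "fake" implementation used during compilation.
    return hidden.new_empty(hidden.shape[0], gate_weight.shape[1])


@torch.compile
def forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Call through the torch.ops namespace, mirroring the
    # torch.ops.vllm.fused_marlin_moe calls in the diff below.
    return torch.ops.vllm.toy_moe_gate(x, w)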

tests/kernels/test_moe.py

@@ -7,12 +7,11 @@ import torch
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                  torch_moe, torch_moe_single)
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
         topk,
         renormalize=False,
     )
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = single_marlin_moe(
+    marlin_output = torch.ops.vllm.single_marlin_moe(
         a,
         qweight,
         scales,
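Because the tests now go through torch.ops.vllm.*, the registered ops can also be validated with torch.library.opcheck, which the imported tests.kernels.utils.opcheck helper appears to wrap. A hedged usage sketch, continuing the hypothetical toy_moe_gate op registered above:

x = torch.randn(4, 8)
w = torch.randn(8, 16)

# opcheck runs a battery of consistency tests: schema correctness, fake-tensor
# agreement with the real implementation, and autograd registration.
torch.library.opcheck(torch.ops.vllm.toy_moe_gate, (x, w))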