[torch.compile] support moe models (#9632)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -7,12 +7,11 @@ import torch

 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

+import vllm.model_executor.layers.fused_moe  # noqa
 from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                  torch_moe, torch_moe_single)
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
-    fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -193,7 +192,7 @@ def test_fused_marlin_moe(
         topk,
         renormalize=False,
     )
-    marlin_output = fused_marlin_moe(
+    marlin_output = torch.ops.vllm.fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -309,7 +308,7 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = single_marlin_moe(
+    marlin_output = torch.ops.vllm.single_marlin_moe(
         a,
         qweight,
         scales,
Reference in New Issue
Block a user