[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)
Co-authored-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
@@ -2,16 +2,14 @@
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
|
||||
torch_moe, torch_moe_single)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
fused_marlin_moe, single_marlin_moe)
|
||||
@@ -24,37 +22,6 @@ from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import seed_everything
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, score, topk):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
def torch_moe_single(a, w, score, topk):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
_, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
for i in range(w.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = a[mask] @ w[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w.shape[1])).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@@ -127,20 +94,10 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
|
||||
|
||||
def stack_and_dev(tensors: List[torch.Tensor]):
|
||||
dev = tensors[0].device
|
||||
return torch.stack(tensors, dim=0).to(dev)
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
||||
@pytest.mark.parametrize("e", [4, 8, 64])
|
||||
@pytest.mark.parametrize("e", [8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@@ -159,9 +116,6 @@ def test_fused_marlin_moe(
|
||||
):
|
||||
seed_everything(7)
|
||||
|
||||
if topk > e:
|
||||
return
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
@@ -241,15 +195,15 @@ def test_fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
scales1,
|
||||
scales2,
|
||||
score,
|
||||
g_idx1,
|
||||
g_idx2,
|
||||
sort_indices1,
|
||||
sort_indices2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale=scales1,
|
||||
w2_scale=scales2,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
@@ -280,9 +234,13 @@ def test_fused_marlin_moe(
|
||||
device="cuda",
|
||||
requires_grad=False)
|
||||
|
||||
zp = torch.empty((0, 0),
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
requires_grad=False)
|
||||
opcheck(torch.ops._moe_C.marlin_gemm_moe,
|
||||
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
|
||||
scales1, g_idx1, sort_indices1, workspace, quant_type, m,
|
||||
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
|
||||
2 * n, k, True, e, topk, block_size_m, True, False))
|
||||
|
||||
|
||||
@@ -291,7 +249,7 @@ def test_fused_marlin_moe(
|
||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
||||
@pytest.mark.parametrize("e", [4, 8, 64])
|
||||
@pytest.mark.parametrize("e", [8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@@ -308,8 +266,6 @@ def test_single_marlin_moe_multiply(
|
||||
num_bits: int,
|
||||
is_k_full: bool,
|
||||
):
|
||||
if topk > e:
|
||||
return
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
@@ -355,13 +311,14 @@ def test_single_marlin_moe_multiply(
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
topk,
|
||||
renormalize=False,
|
||||
g_idx=g_idx,
|
||||
sort_indices=sort_indices,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user