[Misc] Fused MoE Marlin support for GPTQ (#8217)
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import MixtralConfig
|
||||
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
fused_marlin_moe, single_marlin_moe)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_quantize)
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, score, topk):
|
||||
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
def torch_moe_single(a, w, score, topk):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
_, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
for i in range(w.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = a[mask] @ w[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w.shape[1])).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@@ -43,11 +65,11 @@ def test_fused_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
score = torch.randn((m, e), device='cuda', dtype=dtype)
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
|
||||
torch_output = torch_moe(a, w1, w2, score, topk)
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
|
||||
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
|
||||
|
||||
def stack_and_dev(tensors: List[torch.Tensor]):
|
||||
dev = tensors[0].device
|
||||
return torch.stack(tensors, dim=0).to(dev)
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
||||
@pytest.mark.parametrize("e", [4, 8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
):
|
||||
torch.manual_seed(7)
|
||||
|
||||
if topk > e:
|
||||
return
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size in (k, n):
|
||||
return
|
||||
|
||||
quant_type = scalar_types.uint4b8
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
for i in range(w2.shape[0]):
|
||||
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
|
||||
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
scales1_l = []
|
||||
g_idx1_l = []
|
||||
sort_indices1_l = []
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order,
|
||||
test_perm)
|
||||
w_ref1_l.append(w_ref1)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
g_idx1 = stack_and_dev(g_idx1_l)
|
||||
sort_indices1 = stack_and_dev(sort_indices1_l)
|
||||
|
||||
w_ref2_l = []
|
||||
qweight2_l = []
|
||||
scales2_l = []
|
||||
g_idx2_l = []
|
||||
sort_indices2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order,
|
||||
test_perm)
|
||||
w_ref2_l.append(w_ref2)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
g_idx2 = stack_and_dev(g_idx2_l)
|
||||
sort_indices2 = stack_and_dev(sort_indices2_l)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, False)
|
||||
|
||||
triton_output = fused_moe(
|
||||
a,
|
||||
w_ref1.transpose(1, 2).contiguous(),
|
||||
w_ref2.transpose(1, 2).contiguous(),
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
)
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
score,
|
||||
g_idx1,
|
||||
g_idx2,
|
||||
sort_indices1,
|
||||
sort_indices2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale=scales1,
|
||||
w2_scale=scales2,
|
||||
)
|
||||
|
||||
assert compute_max_diff(marlin_output, triton_output) < 4e-2
|
||||
|
||||
|
||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
||||
"don't run it in automated tests.")
|
||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
||||
@pytest.mark.parametrize("e", [4, 8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
def test_marlin_moe_mmm(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
):
|
||||
if topk > e:
|
||||
return
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == k:
|
||||
return
|
||||
|
||||
quant_type = scalar_types.uint4b8
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref_l = []
|
||||
qweights_l = []
|
||||
scales_l = []
|
||||
g_idx_l = []
|
||||
sort_indices_l = []
|
||||
|
||||
for i in range(w.shape[0]):
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
|
||||
w_ref_l.append(w_ref)
|
||||
qweights_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
g_idx_l.append(g_idx)
|
||||
sort_indices_l.append(sort_indices)
|
||||
|
||||
w_ref = stack_and_dev(w_ref_l)
|
||||
qweight = stack_and_dev(qweights_l).contiguous()
|
||||
scales = stack_and_dev(scales_l)
|
||||
g_idx = stack_and_dev(g_idx_l)
|
||||
sort_indices = stack_and_dev(sort_indices_l)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
marlin_output = single_marlin_moe(a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
topk,
|
||||
renormalize=False)
|
||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user