[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
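For orientation, here is a minimal sketch (not part of the commit) of how the new quant_type test parameter selects a quantization helper in the updated test below. The function name quant_helper_for is hypothetical; the mapping is inferred only from the branches visible in this diff.

    from vllm.scalar_type import ScalarType, scalar_types

    def quant_helper_for(quant_type: ScalarType) -> str:
        # uint4 / uint8 carry explicit zero points, so the test uses awq_marlin_quantize
        if quant_type in [scalar_types.uint4, scalar_types.uint8]:
            return "awq_marlin_quantize"
        # fp8 weights go through the new marlin_quant_fp8_torch helper
        # (the test only exercises it with group_size -1 or 128 and act_order=False)
        if quant_type == scalar_types.float8_e4m3fn:
            return "marlin_quant_fp8_torch"
        # uint4b8 / uint8b128 take the GPTQ-style marlin_quantize path
        return "marlin_quantize"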
@@ -11,19 +11,20 @@ from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 import vllm.model_executor.layers.fused_moe  # noqa
-from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
-                                 torch_moe_single)
+from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    marlin_quant_fp8_torch)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     awq_marlin_quantize, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     quantize_weights)
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
-from vllm.scalar_type import scalar_types
+from vllm.scalar_type import ScalarType, scalar_types
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]
@@ -285,7 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
                                atol=mixtral_moe_tol[dtype])
 
 
-@pytest.mark.parametrize("m", [1, 33, 123])
+@pytest.mark.parametrize("m", [1, 123, 666])
 @pytest.mark.parametrize("n", [128, 1024])
 @pytest.mark.parametrize("k", [256, 2048])
 @pytest.mark.parametrize("e", [4, 12])
@@ -294,8 +295,10 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("group_size", [-1, 32, 128])
 @pytest.mark.parametrize("act_order", [True, False])
-@pytest.mark.parametrize("num_bits", [4, 8])
-@pytest.mark.parametrize("has_zp", [True, False])
+@pytest.mark.parametrize("quant_type", [
+    scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8,
+    scalar_types.float8_e4m3fn
+])
 @pytest.mark.parametrize("is_k_full", [True, False])
 @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
 def test_fused_marlin_moe(
@@ -308,14 +311,22 @@ def test_fused_marlin_moe(
     dtype: torch.dtype,
     group_size: int,
     act_order: bool,
-    num_bits: int,
-    has_zp: bool,
+    quant_type: ScalarType,
     is_k_full: bool,
 ):
     current_platform.seed_everything(7)
+    torch.cuda.manual_seed(0)
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+
+    if quant_type == scalar_types.float8_e4m3fn:
+        if group_size not in [-1, 128]:
+            return
+        if act_order:
+            return
 
-    # Filter act_order
     if act_order:
+        if quant_type == scalar_types.float8_e4m3fn:
+            return
         if group_size == -1:
             return
         if group_size in (k, n):
@@ -326,17 +337,9 @@ def test_fused_marlin_moe(
         if not is_k_full:
             return
 
-    if has_zp:
-        # we don't build kernel for int8 with zero
-        if num_bits == 8:
-            return
-        quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
-    else:
-        quant_type = scalar_types.uint4b8 \
-            if num_bits == 4 else scalar_types.uint8b128
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
 
     if ep_size > 1:
         local_e = e // ep_size
@@ -364,17 +367,23 @@ def test_fused_marlin_moe(
             qweight1_l.append(qweight1)
             scales1_l.append(scales1)
             zeros1_l.append(zeros1)
-        else:
+        elif quant_type != scalar_types.float8_e4m3fn:
             test_perm = torch.randperm(k)
-            quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
-                                        group_size, act_order, test_perm)
-            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
+            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
+                marlin_quantize(w1[i].transpose(1, 0), quant_type,
+                                group_size, act_order, test_perm)
 
             w_ref1_l.append(w_ref1.T)
             qweight1_l.append(qweight1)
             scales1_l.append(scales1)
             g_idx1_l.append(g_idx1)
             sort_indices1_l.append(sort_indices1)
+        else:
+            w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
+                w1[i], group_size)
+            w_ref1_l.append(w_ref1.T)
+            qweight1_l.append(qweight1)
+            scales1_l.append(scales1)
 
     w_ref1 = stack_and_dev(w_ref1_l)
     qweight1 = stack_and_dev(qweight1_l).contiguous()
@@ -399,17 +408,23 @@ def test_fused_marlin_moe(
             qweight2_l.append(qweight2)
             scales2_l.append(scales2)
             zeros2_l.append(zeros2)
-        else:
+        elif quant_type != scalar_types.float8_e4m3fn:
             test_perm = torch.randperm(n)
-            quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
-                                        group_size, act_order, test_perm)
-            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
+            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
+                marlin_quantize(w2[i].transpose(1, 0), quant_type,
+                                group_size, act_order, test_perm)
 
             w_ref2_l.append(w_ref2.T)
             qweight2_l.append(qweight2)
             scales2_l.append(scales2)
             g_idx2_l.append(g_idx2)
             sort_indices2_l.append(sort_indices2)
+        else:
+            w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
+                w2[i], group_size)
+            w_ref2_l.append(w_ref2.T)
+            qweight2_l.append(qweight2)
+            scales2_l.append(scales2)
 
     w_ref2 = stack_and_dev(w_ref2_l)
     qweight2 = stack_and_dev(qweight2_l).contiguous()
@@ -442,102 +457,10 @@ def test_fused_marlin_moe(
         sort_indices2=sort_indices2,
         w1_zeros=zeros1,
         w2_zeros=zeros2,
-        num_bits=num_bits,
+        quant_type_id=quant_type.id,
         is_k_full=is_k_full)
 
-    torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
-
-
-@pytest.mark.skip("This test is here for the sake of debugging, "
-                  "don't run it in automated tests.")
-@pytest.mark.parametrize("m", [1, 33, 123])
-@pytest.mark.parametrize("n", [128, 1024])
-@pytest.mark.parametrize("k", [256, 2048])
-@pytest.mark.parametrize("e", [4, 12])
-@pytest.mark.parametrize("topk", [2, 3])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("group_size", [-1, 32, 128])
-@pytest.mark.parametrize("act_order", [True, False])
-@pytest.mark.parametrize("num_bits", [4, 8])
-@pytest.mark.parametrize("has_zp", [True, False])
-@pytest.mark.parametrize("is_k_full", [True, False])
-def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
-                                    dtype: torch.dtype, group_size: int,
-                                    act_order: bool, num_bits: int,
-                                    has_zp: bool, is_k_full: bool):
-    # Filter act_order
-    if act_order:
-        if group_size == -1:
-            return
-        if group_size in (k, n):
-            return
-        if has_zp:
-            return
-    else:
-        if not is_k_full:
-            return
-
-    if has_zp:
-        quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
-    else:
-        quant_type = scalar_types.uint4b8 \
-            if num_bits == 4 else scalar_types.uint8b128
-    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
-
-    w_ref_l = []
-    qweight_l = []
-    scales_l = []
-    zeros_l = []
-    g_idx_l = []
-    sort_indices_l = []
-
-    for i in range(w.shape[0]):
-        if has_zp:
-            w_ref, qweight, scales, zeros = awq_marlin_quantize(
-                w[i].transpose(1, 0), quant_type, group_size)
-
-            w_ref_l.append(w_ref.T)
-            qweight_l.append(qweight)
-            scales_l.append(scales)
-            zeros_l.append(zeros)
-        else:
-            test_perm = torch.randperm(k)
-            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
-                w[i].transpose(1, 0), quant_type, group_size, act_order,
-                test_perm)
-
-            w_ref_l.append(w_ref.T)
-            qweight_l.append(qweight)
-            scales_l.append(scales)
-            g_idx_l.append(g_idx)
-            sort_indices_l.append(sort_indices)
-
-    w_ref = stack_and_dev(w_ref_l)
-    qweight = stack_and_dev(qweight_l).contiguous()
-    scales = stack_and_dev(scales_l)
-    g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
-    zeros = stack_and_dev(zeros_l) if zeros_l else None
-    sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
-
-    score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = torch.ops.vllm.single_marlin_moe(
-        a,
-        qweight,
-        scales,
-        score,
-        topk,
-        renormalize=False,
-        g_idx=g_idx,
-        sort_indices=sort_indices,
-        w_zeros=zeros,
-        num_bits=num_bits,
-        is_k_full=is_k_full,
-    )
-
-    torch_output = torch_moe_single(a, w_ref, score, topk)
-
-    torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
 
 
 def test_moe_align_block_size_opcheck():