[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -1,164 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Test AWQ with fused MoE Marlin kernels.
|
||||
|
||||
Run `pytest tests/kernels/test_awq_marlin.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
|
||||
torch_moe_single)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
awq_marlin_quantize)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [2, 6]
|
||||
GROUP_SIZES = [-1, 32, 128]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.skipif(not (ops.supports_moe_ops
|
||||
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
def test_fused_marlin_moe_awq(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
group_size: int,
|
||||
):
|
||||
torch.manual_seed(7)
|
||||
|
||||
num_bits = 4
|
||||
quant_type = scalar_types.uint4
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref1_l = []
|
||||
qweights1_l = []
|
||||
scales1_l = []
|
||||
zp1_l = []
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size)
|
||||
w_ref1_l.append(w_ref1)
|
||||
qweights1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
zp1_l.append(zp1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweights1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
zp1 = stack_and_dev(zp1_l)
|
||||
|
||||
w_ref2_l = []
|
||||
qweights2_l = []
|
||||
scales2_l = []
|
||||
zp2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size)
|
||||
w_ref2_l.append(w_ref2)
|
||||
qweights2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
zp2_l.append(zp2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweights2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
zp2 = stack_and_dev(zp2_l)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, False)
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
scales1,
|
||||
scales2,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_zeros=zp1,
|
||||
w2_zeros=zp2,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
|
||||
score, topk, None)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 4e-2
|
||||
|
||||
|
||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
||||
"don't run it in automated tests.")
|
||||
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
|
||||
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
|
||||
@pytest.mark.parametrize("k", [128, 1024, 512])
|
||||
@pytest.mark.parametrize("e", [8, 64])
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
def test_single_marlin_moe_multiply_awq(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
group_size: int,
|
||||
):
|
||||
torch.manual_seed(7)
|
||||
|
||||
num_bits = 4
|
||||
quant_type = scalar_types.uint4
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref_l = []
|
||||
qweights_l = []
|
||||
scales_l = []
|
||||
zp_l = []
|
||||
|
||||
for i in range(w.shape[0]):
|
||||
w_ref, qweight, scales, zp = awq_marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size)
|
||||
w_ref_l.append(w_ref)
|
||||
qweights_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
zp_l.append(zp)
|
||||
|
||||
w_ref = stack_and_dev(w_ref_l)
|
||||
qweight = stack_and_dev(qweights_l).contiguous()
|
||||
scales = stack_and_dev(scales_l).contiguous()
|
||||
zp = stack_and_dev(zp_l).contiguous()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
marlin_output = torch.ops.vllm.single_marlin_moe(a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
w_zeros=zp,
|
||||
num_bits=num_bits)
|
||||
|
||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
||||
@@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.qqq import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_permute_scales, query_marlin_supported_quant_types)
|
||||
marlin_make_workspace_new, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
pack_fp8_to_int32)
|
||||
marlin_quant_fp8_torch)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
|
||||
marlin_weights)
|
||||
@@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16):
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
query_marlin_supported_quant_types(False, False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
@@ -220,38 +221,50 @@ def test_gptq_marlin_gemm(
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128]:
|
||||
return
|
||||
if act_order:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
|
||||
b_weight.T, group_size)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
else:
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(w_ref.device)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace.scratch, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1], is_k_full, False,
|
||||
use_atomic_add, use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
|
||||
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
@@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", [8])
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_fp8_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
dtype,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k), dtype=dtype)
|
||||
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
||||
|
||||
# WEIGHTS
|
||||
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device="cuda"),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(s=scales,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=-1)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
opcheck(torch.ops._C.fp8_marlin_gemm,
|
||||
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
|
||||
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=a_input,
|
||||
b_q_weight=marlin_qweight,
|
||||
b_scales=marlin_scales,
|
||||
workspace=workspace.scratch,
|
||||
num_bits=num_bits,
|
||||
size_m=a_input.shape[0],
|
||||
size_n=b_weight.shape[1],
|
||||
size_k=a_input.shape[1],
|
||||
)
|
||||
output_ref = torch.matmul(a_input, b_weight)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@@ -432,25 +371,23 @@ def test_awq_marlin_gemm(
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
is_k_full = True
|
||||
has_zp = True
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
has_zp=has_zp,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
@@ -508,23 +445,22 @@ def test_hqq_marlin_gemm(
|
||||
g_idx = marlin_make_empty_g_idx(dev)
|
||||
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(b_weight.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_w_q,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[0],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
has_zp=True,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=True,
|
||||
)
|
||||
@@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input():
|
||||
b_weight, quant_type, group_size, False)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
has_zp=False,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
|
||||
Reference in New Issue
Block a user