[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Author: Jinzhen Lin
Date: 2025-05-06 00:39:30 +08:00
Committed by: GitHub
Parent: f62cad6431
Commit: 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions


@@ -18,9 +18,10 @@ from vllm.model_executor.layers.quantization.qqq import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
-    marlin_permute_scales, query_marlin_supported_quant_types)
+    marlin_make_workspace_new, marlin_permute_scales,
+    query_marlin_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    pack_fp8_to_int32)
+    marlin_quant_fp8_torch)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
     marlin_weights)
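The import hunk above swaps in the two helpers used throughout the rest of this diff: `marlin_make_workspace_new` replaces manual `MarlinWorkspace` construction, and `marlin_quant_fp8_torch` replaces `pack_fp8_to_int32` on the fp8 path. A minimal sketch of the workspace change, assuming a CUDA build of vLLM at this commit (both spellings are taken from the hunks below, not from documentation):

```python
import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    marlin_make_workspace_new)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace)

device = torch.device("cuda")

# Old pattern (removed below): a scratch buffer sized from the output
# dimension and the kernel's parallelism limits, passed as `ws.scratch`.
ws = MarlinWorkspace(4096, GPTQ_MARLIN_MIN_THREAD_N,
                     GPTQ_MARLIN_MAX_PARALLEL)

# New pattern: one workspace per device, passed to the op directly.
workspace = marlin_make_workspace_new(device)
```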
@@ -73,7 +74,7 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@@ -138,7 +139,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False))
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@@ -220,38 +221,50 @@ def test_gptq_marlin_gemm(
     if group_size == size_k:
         return
     if size_k % group_size != 0:
         return

     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))

-    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
-        b_weight, quant_type, group_size, act_order)
+    if quant_type == scalar_types.float8_e4m3fn:
+        if group_size not in [-1, 128]:
+            return
+        if act_order:
+            return
+        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
+            b_weight.T, group_size)
+        g_idx = None
+        sort_indices = None
+    else:
+        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+            b_weight, quant_type, group_size, act_order)

     marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

-    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
-                                GPTQ_MARLIN_MAX_PARALLEL)
+    workspace = marlin_make_workspace_new(w_ref.device)

-    opcheck(torch.ops._C.gptq_marlin_gemm,
-            (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
-             workspace.scratch, quant_type.id, a_input.shape[0],
-             b_weight.shape[1], a_input.shape[1], is_k_full, False,
-             use_atomic_add, use_fp32_reduce, False),
-            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+    opcheck(
+        torch.ops._C.gptq_marlin_gemm,
+        (a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
+         workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
+         a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS)

     output = ops.gptq_marlin_gemm(
         a_input,
+        None,
         marlin_q_w,
         marlin_s,
         marlin_zp,
         g_idx,
         sort_indices,
-        workspace.scratch,
+        workspace,
         quant_type,
         a_input.shape[0],
         b_weight.shape[1],
         a_input.shape[1],
         is_k_full=is_k_full,
-        has_zp=False,
         use_atomic_add=use_atomic_add,
         use_fp32_reduce=use_fp32_reduce,
         is_zp_float=False,
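The hunk above captures the op's signature change in one place; a minimal annotated sketch of the new call shape follows (the role of the new second argument is my reading of this diff, not confirmed by it):

```python
# Sketch only: names mirror the test above; `workspace` comes from
# marlin_make_workspace_new(device) and replaces the old
# MarlinWorkspace(...).scratch buffer.
output = ops.gptq_marlin_gemm(
    a_input,
    None,               # new 2nd positional arg; None in every call in this
                        # file (likely an optional output buffer -- assumption)
    marlin_q_w,         # Marlin-format quantized weight
    marlin_s,           # scales
    marlin_zp,          # zero points (empty g_idx tensor when unused)
    g_idx,              # act_order indices, or None on the fp8 path
    sort_indices,       # act_order permutation, or None on the fp8 path
    workspace,          # workspace tensor itself, no longer `.scratch`
    quant_type,         # ScalarType; the raw torch.ops._C op takes .id
    a_input.shape[0],   # size_m
    b_weight.shape[1],  # size_n
    a_input.shape[1],   # size_k
    is_k_full=is_k_full,
    # has_zp is gone; zero-point handling now follows the quant type
    use_atomic_add=use_atomic_add,
    use_fp32_reduce=use_fp32_reduce,
    is_zp_float=False,
)
```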
@@ -326,80 +339,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
# WEIGHTS
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device="cuda"),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=size_k,
size_n=size_n,
group_size=-1)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
opcheck(torch.ops._C.fp8_marlin_gemm,
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
b_scales=marlin_scales,
workspace=workspace.scratch,
num_bits=num_bits,
size_m=a_input.shape[0],
size_n=b_weight.shape[1],
size_k=a_input.shape[1],
)
output_ref = torch.matmul(a_input, b_weight)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@@ -432,25 +371,23 @@ def test_awq_marlin_gemm(
     g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
     sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
     is_k_full = True
-    has_zp = True

-    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
-                                GPTQ_MARLIN_MAX_PARALLEL)
+    workspace = marlin_make_workspace_new(a_input.device)

     output = ops.gptq_marlin_gemm(
         a_input,
+        None,
         marlin_q_w,
         marlin_s,
         marlin_zp,
         g_idx,
         sort_indices,
-        workspace.scratch,
+        workspace,
         quant_type,
         a_input.shape[0],
         b_weight.shape[1],
         a_input.shape[1],
         is_k_full=is_k_full,
-        has_zp=has_zp,
         use_fp32_reduce=use_fp32_reduce,
         is_zp_float=False,
     )
@@ -508,23 +445,22 @@ def test_hqq_marlin_gemm(
     g_idx = marlin_make_empty_g_idx(dev)
     g_idx_sort_indices = marlin_make_empty_g_idx(dev)

-    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
-                                GPTQ_MARLIN_MAX_PARALLEL)
+    workspace = marlin_make_workspace_new(b_weight.device)

     output = ops.gptq_marlin_gemm(
         a_input,
+        None,
         marlin_w_q,
         marlin_s,
         marlin_zp,
         g_idx,
         g_idx_sort_indices,
-        workspace.scratch,
+        workspace,
         quant_type,
         a_input.shape[0],
         b_weight.shape[0],
         a_input.shape[1],
         is_k_full=True,
-        has_zp=True,
         use_fp32_reduce=use_fp32_reduce,
         is_zp_float=True,
     )
@@ -621,23 +557,22 @@ def test_marlin_gemm_subset_input():
         b_weight, quant_type, group_size, False)

     marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

-    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
-                                GPTQ_MARLIN_MAX_PARALLEL)
+    workspace = marlin_make_workspace_new(a_input.device)

     output = ops.gptq_marlin_gemm(
         a_input,
+        None,
         marlin_q_w,
         marlin_s,
         marlin_zp,
         g_idx,
         sort_indices,
-        workspace.scratch,
+        workspace,
         quant_type,
         a_input.shape[0],
         b_weight.shape[1],
         a_input.shape[1],
         is_k_full=True,
-        has_zp=False,
         use_atomic_add=False,
         use_fp32_reduce=True,
         is_zp_float=False,