[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
Authored by Jinzhen Lin on 2025-11-29 23:19:33 +08:00; committed by GitHub.
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions


@@ -5,6 +5,8 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""
import itertools
import pytest
import torch
@@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
@@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
@@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights,
sort_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
@@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512]
HQQ_SUPPORTED_GROUP_SIZES = [64]
MARLIN_REPACK_NK_FACTORS = [
(4, 8),
(7, 5),
(13, 11),
]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
@@ -74,6 +84,64 @@ MNK_FACTORS = [
DTYPES = [torch.float16, torch.bfloat16]
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
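# Illustrative sketch (not part of the original file): each config above is
# expanded by marlin_generate_valid_test_cases() below. "group_blocks" counts
# blocks of 16 elements along K, so the effective group size is
# group_blocks * 16, while values <= 0 (e.g. -1) are passed through and mean a
# single channelwise group. The helper name is hypothetical.
def _group_blocks_to_group_size(group_blocks: int) -> int:
    # e.g. 2 -> 32, 8 -> 128, -1 -> channelwise
    return group_blocks if group_blocks <= 0 else group_blocks * 16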
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
@@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
qweight_unpacked = torch.randint(
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
)
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)
torch_res = torch.where(
qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
)
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
torch_res = torch_res.to(torch.int8).view(torch.int32)
assert (cuda_res == torch_res).all()
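# Illustrative scalar restatement (not part of the original file) of the
# remapping verified above for a single unpacked 4-bit weight w in [0, 16):
# values at or above the implicit zero-point of 8 become w - 8, the rest are
# folded to 15 - w. The helper name is hypothetical.
def _int4_fp8_remap_ref(w: int) -> int:
    assert 0 <= w < 16
    return w - 8 if w >= 8 else 15 - w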
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
group_size = 128
qweight_unpacked = torch.randint(
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
)
qzeros_unpacked = torch.randint(
0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
)
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)
repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
torch_res = qweight_unpacked - repeated_zp
torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
torch_res = torch_res.to(torch.int8).view(torch.int32)
assert (cuda_res == torch_res).all()
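# Illustrative scalar restatement (not part of the original file) of the AWQ
# variant checked above: the per-group zero-point zp replaces the fixed bias of
# 8, and results that would go negative are again folded to 15 - w.
def _int4_fp8_remap_awq_ref(w: int, zp: int) -> int:
    assert 0 <= w < 16 and 0 <= zp < 16
    return w - zp if w >= zp else 15 - w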
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
@@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
-k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
+k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
-m_factor, n_factor, k_factor = mnk_factors
+n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
group_size = 128
# Filter act_order
if act_order:
@@ -109,6 +230,8 @@ def test_gptq_marlin_repack(
return
if group_size == size_k:
return
if is_a_8bit:
return
# Normalize group_size
if group_size == -1:
@@ -133,23 +256,19 @@ def test_gptq_marlin_repack(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format
-weight_perm = get_weight_perm(quant_type.size_bits)
+weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights(
-q_w, size_k, size_n, quant_type.size_bits, weight_perm
+q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
)
opcheck(
torch.ops._C.gptq_marlin_repack,
-(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
+(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
-q_w_gptq,
-sort_indices,
-size_k,
-size_n,
-quant_type.size_bits,
+q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
)
torch.cuda.synchronize()
@@ -163,18 +282,15 @@ def test_gptq_marlin_repack(
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
-# Normalize group_size
-if group_size == -1:
-group_size = size_k
-assert group_size <= size_k
+group_size = 128
# Create input
b_weight = rand_data((size_k, size_n))
@@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format
-weight_perm = get_weight_perm(quant_type.size_bits)
+weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights(
-q_w, size_k, size_n, quant_type.size_bits, weight_perm
+q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
)
opcheck(
-torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
+torch.ops._C.awq_marlin_repack,
+(q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
-q_w_awq,
-size_k,
-size_n,
-quant_type.size_bits,
+q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
)
torch.cuda.synchronize()
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
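# For reference, the repack entry points as exercised by the two tests above
# (argument names other than is_a_8bit are paraphrased; the authoritative
# signatures live in vllm/_custom_ops.py):
#   ops.gptq_marlin_repack(q_w_gptq, sort_indices, size_k, size_n, num_bits, is_a_8bit)
#   ops.awq_marlin_repack(q_w_awq, size_k, size_n, num_bits, is_a_8bit)
# is_a_8bit=True appears to request the weight permutation that matches 8-bit
# (int8/fp8) activations, mirroring the extra argument to get_weight_perm().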
def marlin_generate_valid_test_cases():
all_combinations = itertools.product(
DENSE_MARLIN_QUANT_TEST_CONFIGS,
MNK_FACTORS,
MARLIN_N_CHUNKS,
MARLIN_K_CHUNKS,
ACT_ORDER_OPTS,
K_FULL_OPTS,
USE_ATOMIC_ADD_OPTS,
USE_FP32_REDUCE_OPTS,
)
def is_invalid(
a_type,
b_type,
c_type,
group_blocks,
size_m,
size_n,
size_k,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
if use_atomic_add:
if use_fp32_reduce:
return False
if (
c_type == scalar_types.bfloat16
and torch.cuda.get_device_capability()[0] < 9
):
return False
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if group_size > 0 and size_k % group_size != 0:
return False
if act_order and group_size in [-1, size_k]:
return False
if group_size == size_k:
return False
if not act_order and is_k_full:
return False
return a_type.size_bits < 16 or a_type is c_type
cases = []
for case in all_combinations:
quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
size_m = mnk_factors[0]
size_n = mnk_factors[1] * n_chunk
size_k = mnk_factors[2] * k_chunk
if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16, scalar_types.bfloat16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (size_m, size_n, size_k) + case[4:]
if is_invalid(*args):
cases.append(args)
return cases
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
(
"a_type, b_type, c_type, group_blocks,"
"size_m, size_n, size_k, act_order, is_k_full,"
"use_atomic_add, use_fp32_reduce"
),
marlin_generate_valid_test_cases(),
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(
-k_chunk,
-n_chunk,
-quant_type,
-group_size,
-mnk_factors,
+a_type,
+b_type,
+c_type,
+group_blocks,
+size_m,
+size_n,
+size_k,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
-dtype,
):
-m_factor, n_factor, k_factor = mnk_factors
-has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]
-size_m = m_factor
-size_k = k_chunk * k_factor
-size_n = n_chunk * n_factor
+group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
if has_zp:
return
if c_type == scalar_types.float16:
dtype = torch.float16
elif c_type == scalar_types.bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("unsupported c_type")
if size_k % group_size != 0:
return
if a_type == scalar_types.int8:
a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
-a_input = rand_data((size_m, size_k), dtype)
-b_weight = rand_data((size_k, size_n), dtype)
-if quant_type == scalar_types.float4_e2m1f:
-if group_size not in [16, 32] or act_order:
-return
-if group_size == 32 and dtype == torch.float16:
-return
+a_input = rand_data((size_m, size_k), dtype=dtype)
+b_weight = rand_data((size_k, size_n), dtype=dtype)
+if b_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
-b_weight.T, group_size
+b_weight.T, group_size, input_dtype=a_dtype
)
else:
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
-b_weight.T, group_size
+b_weight.T, group_size, input_dtype=a_dtype
)
marlin_s2 = None
g_idx = None
sort_indices = None
marlin_zp = None
-elif quant_type == scalar_types.float8_e4m3fn:
-if group_size not in [-1, 128]:
-return
-if act_order:
-return
-w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
+elif b_type == scalar_types.float8_e4m3fn:
+w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
+b_weight.T, group_size, input_dtype=a_dtype
+)
g_idx = None
sort_indices = None
marlin_zp = None
marlin_s2 = None
elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
-b_weight, quant_type, group_size
+b_weight, b_type, group_size, input_dtype=a_dtype
)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
-b_weight, quant_type, group_size, act_order
+b_weight, b_type, group_size, act_order, input_dtype=a_dtype
)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
use_atomic_add,
use_fp32_reduce,
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
if a_type == scalar_types.int8:
a_input, a_scales = per_token_quant_int8(a_input)
a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
a_input_ref = a_input_ref.to(dtype)
if group_size != -1:
a_scales = a_scales / 4096 * marlin_s.max()
a_scales = a_scales.float()
marlin_s = marlin_s / marlin_s.max() * 4096
marlin_s = marlin_s.round().to(torch.int16).view(dtype)
elif a_type == scalar_types.float8_e4m3fn:
a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
a_input_ref = a_input_ref.to(dtype)
else:
assert a_type.size_bits == 16
a_input_ref = a_input
a_scales = None
output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
output,
marlin_q_w,
None,
marlin_s,
a_scales,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
-quant_type,
+b_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
@@ -352,12 +527,9 @@ def test_gptq_marlin_gemm(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
-output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
+output_ref = torch.matmul(a_input_ref, w_ref)
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
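# Rough reference (an assumption-laden sketch, not the vLLM helper itself) for
# the symmetric per-token int8 activation quantization used in the int8 branch
# above; per_token_quant_int8() may round, clamp, or shape its scales differently.
def _per_token_quant_int8_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    scales = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x.float() / scales), -128, 127).to(torch.int8)
    return q, scales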
@@ -507,6 +679,7 @@ def test_hqq_marlin_gemm(
None,
marlin_s,
None,
None,
marlin_zp,
g_idx,
g_idx_sort_indices,
@@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input():
None,
marlin_s,
None,
None,
marlin_zp,
g_idx,
sort_indices,
@@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m):
marlin_bias,
marlin_s,
None,
None,
marlin_zp,
g_idx,
sort_indices,
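# Note on the three trailing hunks: the extra `None,` threaded through the
# gptq_marlin_gemm call sites in test_hqq_marlin_gemm, test_marlin_gemm_subset_input
# and test_marlin_gemm_with_bias appears to fill the new activation-scales slot
# (passed as a_scales in test_gptq_marlin_gemm above), so those tests keep
# running with unquantized 16-bit activations.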