[Kernel] fp4 marlin kernel (#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -20,6 +20,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
@@ -190,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
|
||||
@pytest.mark.parametrize(
|
||||
"group_size",
|
||||
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
@@ -210,6 +213,7 @@ def test_gptq_marlin_gemm(
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
@@ -220,6 +224,8 @@ def test_gptq_marlin_gemm(
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
if has_zp:
|
||||
return
|
||||
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
@@ -227,7 +233,15 @@ def test_gptq_marlin_gemm(
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
if quant_type == scalar_types.float8_e4m3fn:
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size != 16 or act_order:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
||||
b_weight.T, group_size)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_zp = None
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128]:
|
||||
return
|
||||
if act_order:
|
||||
@@ -236,26 +250,39 @@ def test_gptq_marlin_gemm(
|
||||
b_weight.T, group_size)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_zp = None
|
||||
marlin_s2 = None
|
||||
elif has_zp:
|
||||
if group_size == 16:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b_weight, quant_type, group_size)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_s2 = None
|
||||
else:
|
||||
if group_size == 16:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
marlin_zp = None
|
||||
marlin_s2 = None
|
||||
|
||||
workspace = marlin_make_workspace_new(w_ref.device)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, None, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace, quant_type.id, a_input.shape[0], b_weight.shape[1],
|
||||
a_input.shape[1], is_k_full, use_atomic_add, use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx,
|
||||
sort_indices, workspace, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
|
||||
use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
@@ -339,67 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_awq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b_weight, quant_type, group_size)
|
||||
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
is_k_full = True
|
||||
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@@ -452,6 +418,7 @@ def test_hqq_marlin_gemm(
|
||||
None,
|
||||
marlin_w_q,
|
||||
marlin_s,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
@@ -564,6 +531,7 @@ def test_marlin_gemm_subset_input():
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
|
||||
Reference in New Issue
Block a user