[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)

Author: Alexander Matveev
Date: 2024-07-21 19:41:42 -04:00
Committed by: GitHub
Parent: 25e778aa16
Commit: 396d92d5e0
21 changed files with 1594 additions and 276 deletions


@@ -12,16 +12,18 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
-    marlin_permute_scales)
+    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS,
+    marlin_make_empty_g_idx, marlin_permute_scales)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     pack_fp8_to_int32)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights)
+    MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
+    marlin_weights)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, quantize_weights, sort_weights)
+    awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp,
+    sort_weights)

 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
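
Note on the new imports: per the call sites below, quantize_weights_with_zp returns a dequantized reference w_ref, integer codes q_w, per-group scales s, and per-group zero points zp. A minimal sketch of asymmetric per-group quantization along those lines (illustrative only; the rounding and clamping details are assumptions, not vLLM's exact implementation):

import torch

def quantize_with_zp_sketch(w: torch.Tensor, num_bits: int, group_size: int):
    # Illustrative only -- not vLLM's quantize_weights_with_zp.
    size_k, size_n = w.shape
    max_q = 2**num_bits - 1
    # View k in groups of group_size rows: (num_groups, group_size, n)
    w_g = w.reshape(size_k // group_size, group_size, size_n)
    w_min = w_g.min(dim=1, keepdim=True).values
    w_max = w_g.max(dim=1, keepdim=True).values
    s = (w_max - w_min).clamp(min=1e-8) / max_q          # per-group scale
    zp = torch.clamp(torch.round(-w_min / s), 0, max_q)  # per-group zero point
    q = torch.clamp(torch.round(w_g / s + zp), 0, max_q)
    w_ref = (q - zp) * s  # what the matmul reference below uses
    return (w_ref.reshape(size_k, size_n),
            q.to(torch.int32).reshape(size_k, size_n),
            s.squeeze(1),
            zp.to(torch.int32).squeeze(1))

The symmetric quantize_weights helper has no zero point, which is why the GPTQ tests further down pass has_zp=False.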
@@ -57,12 +59,12 @@ def rand_data(shape, dtype=torch.float16):
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
mnk_factors):
def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
@@ -120,12 +122,60 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Create input
b_weight = rand_data((size_k, size_n))
# Quantize
w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits,
group_size)
# Pack to AWQ format
q_w_awq = awq_pack(q_w, num_bits, size_k, size_n)
# Pack to Marlin format
weight_perm = get_weight_perm(num_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
size_k,
size_n,
num_bits,
)
torch.cuda.synchronize()
assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
def test_marlin_gemm(
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
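
The test_awq_marlin_repack test above builds the same integer codes in two layouts: awq_pack packs several codes per int32 word along the n dimension, while marlin_weights applies Marlin's tile permutation; ops.awq_marlin_repack must map the former onto the latter. A hedged sketch of 4-bit AWQ-style packing (the column interleave order is an assumption drawn from common AWQ layouts; quant_utils.awq_pack is authoritative):

import torch

def awq_pack_4bit_sketch(q_w: torch.Tensor, size_k: int, size_n: int):
    # Sketch only; consult quant_utils.awq_pack for the real layout.
    # Requires size_n to be divisible by the pack factor.
    pack_factor = 32 // 4  # eight 4-bit codes per int32
    # Assumed AWQ column interleave within each group of 8 columns.
    interleave = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])
    q = q_w.reshape(size_k, size_n // pack_factor, pack_factor)
    q = q[..., interleave].to(torch.int32)
    packed = torch.zeros(size_k, size_n // pack_factor, dtype=torch.int32)
    for i in range(pack_factor):
        packed |= q[..., i] << (4 * i)  # nibble i goes to bits [4i, 4i+4)
    return packed  # shape (size_k, size_n // 8)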
@@ -155,6 +205,8 @@ def test_marlin_gemm(
     w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
         b_weight, num_bits, group_size, act_order)

+    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
+
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)
@@ -162,6 +214,7 @@ def test_marlin_gemm(
         a_input,
         marlin_q_w,
         marlin_s,
+        marlin_zp,
         g_idx,
         sort_indices,
         workspace.scratch,
@@ -170,6 +223,7 @@ def test_marlin_gemm(
         b_weight.shape[1],
         a_input.shape[1],
         is_k_full,
+        has_zp=False,
     )

     output_ref = torch.matmul(a_input, w_ref)
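
Summarizing the two call sites of the now-unified ops.gptq_marlin_gemm entry point in this diff: unused inputs are passed as empty placeholder tensors rather than None, and has_zp selects the dequantization path:

# GPTQ path (above): act_order supported, no zero points.
#   g_idx, sort_indices: real tensors from marlin_quantize
#   marlin_zp:           empty placeholder via marlin_make_empty_g_idx
#   has_zp=False
#
# AWQ path (test_awq_marlin_gemm below): zero points, no act_order.
#   g_idx, sort_indices: empty int tensors
#   marlin_zp:           real tensor from awq_marlin_quantize
#   has_zp=True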
@@ -188,7 +242,8 @@ def test_marlin_gemm(
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
@@ -301,3 +356,65 @@ def test_fp8_marlin_gemm(
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, num_bits, group_size)
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
is_k_full = True
has_zp = True
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.gptq_marlin_gemm(
a_input,
marlin_q_w,
marlin_s,
marlin_zp,
g_idx,
sort_indices,
workspace.scratch,
num_bits,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
has_zp,
)
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
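
Both GEMM tests compare against a dequantized-reference matmul under a 4% relative-error bound rather than exact equality, since low-bit quantization and half-precision accumulation both round. compute_max_diff comes from the test utilities; a plausible definition consistent with its use here (an assumption, not a quote of the helper) is a mean relative error:

import torch

def compute_max_diff_sketch(output, output_ref):
    # Mean absolute error, normalized by the reference magnitude.
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))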