[Kernel/Quant] Remove the original marlin format and qqq (#23204)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -13,11 +13,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
|
||||
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
@@ -31,8 +27,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_weights)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
|
||||
marlin_qqq_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
from vllm.scalar_type import scalar_types
|
||||
@@ -449,68 +443,6 @@ def test_hqq_marlin_gemm(
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
    """Check ``ops.marlin_qqq_gemm`` against a ``torch.matmul`` reference.

    The problem size is ``(m_factor, k_chunk * k_factor, n_chunk * n_factor)``.
    Activations are quantized per row to int8, weights are quantized with
    ``marlin_qqq_quantize``, and the kernel output is compared with the
    product of the re-scaled int8 activations and the reference weights
    returned by the quantizer. The test passes when ``compute_max_diff`` of
    the two results stays below 0.04.
    """
    m_factor, n_factor, k_factor = mnk_factors
    size_m, size_k, size_n = m_factor, k_chunk * k_factor, n_chunk * n_factor

    # Keep this creation order: activations first, then weights.
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Per-row symmetric int8 quantization of the activations: the scale maps
    # each row's largest magnitude onto the int8 maximum.
    int8_info = torch.iinfo(torch.int8)
    row_abs_max = a_input.abs().max(dim=-1, keepdim=True).values
    s_a = row_abs_max.div(int8_info.max).to(torch.float)
    q_a = (a_input / s_a).round()
    q_a = q_a.clamp(int8_info.min, int8_info.max).to(torch.int8)

    # Weight quantization; `w_ref` is the reference weight used for the
    # torch.matmul comparison below.
    w_ref, q_w, s_group, s_channel = marlin_qqq_quantize(
        b_weight, num_bits, group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    # The exact same positional arguments feed both the op-registration
    # check and the actual kernel invocation.
    gemm_args = (
        q_a,
        q_w,
        s_a,
        s_channel,
        s_group,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    opcheck(torch.ops._C.marlin_qqq_gemm, gemm_args)
    output = ops.marlin_qqq_gemm(*gemm_args)

    # Reference: re-scale the int8 activations in half precision and multiply
    # by the reference weights.
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_subset_input():
|
||||
quant_type = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
@@ -602,18 +534,3 @@ def test_marlin_gemm_with_bias(size_m):
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_opcheck():
    """Smoke-test ``torch.ops._C.marlin_gemm`` and run ``opcheck`` on it.

    Runs the kernel twice on identical inputs, asserts that both results are
    close to each other, and then validates the op with ``opcheck``.
    """
    size_m = 2048
    size_n = 4096
    size_k = 4096

    # Activations are (size_m, size_k). The previous code used size_n for the
    # second dimension, which only worked because size_n == size_k here.
    a = torch.rand((size_m, size_k), device='cuda', dtype=torch.float16)
    # Packed int32 weights.
    # NOTE(review): (256, 8192) presumably equals
    # (size_k // 16, size_n * 16 // 8) for the original Marlin layout;
    # confirm before changing the sizes above.
    w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
    # One scale row per quantization group, one column per output channel.
    # The previous code used size_k for the column count.
    # NOTE(review): 32 rows presumably correspond to size_k // 128; confirm.
    s = torch.full((32, size_n), 0.125, device='cuda', dtype=torch.float16)
    wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                         GPTQ_MARLIN_MAX_PARALLEL).scratch

    # Two back-to-back launches on the same inputs must agree.
    x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)

    opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
|
||||
|
||||
Reference in New Issue
Block a user