Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -11,24 +12,44 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new,
|
||||
marlin_permute_bias,
|
||||
marlin_permute_scales,
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
|
||||
rand_marlin_weight_nvfp4_like)
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
rand_marlin_weight_mxfp4_like,
|
||||
rand_marlin_weight_nvfp4_like,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch)
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
|
||||
marlin_weights)
|
||||
MarlinWorkspace,
|
||||
awq_marlin_quantize,
|
||||
get_weight_perm,
|
||||
marlin_quantize,
|
||||
marlin_weights,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
marlin_24_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
awq_pack,
|
||||
gptq_pack,
|
||||
gptq_quantize_weights,
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
@@ -56,24 +77,27 @@ DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref))
|
||||
torch.abs(output_ref)
|
||||
)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False, False))
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
act_order, mnk_factors):
|
||||
def test_gptq_marlin_repack(
|
||||
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
@@ -96,7 +120,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
b_weight, quant_type, group_size, act_order
|
||||
)
|
||||
|
||||
# Pack to GPTQ format
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
@@ -109,11 +134,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits)
|
||||
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
|
||||
weight_perm)
|
||||
marlin_q_w_1 = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_repack,
|
||||
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_repack,
|
||||
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.gptq_marlin_repack(
|
||||
@@ -128,16 +156,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
mnk_factors):
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
@@ -152,21 +180,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize
|
||||
w_ref, q_w, s, zp = quantize_weights(b_weight,
|
||||
quant_type,
|
||||
group_size,
|
||||
zero_points=True)
|
||||
w_ref, q_w, s, zp = quantize_weights(
|
||||
b_weight, quant_type, group_size, zero_points=True
|
||||
)
|
||||
|
||||
# Pack to AWQ format
|
||||
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits)
|
||||
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
|
||||
weight_perm)
|
||||
marlin_q_w_1 = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.awq_marlin_repack,
|
||||
(q_w_awq, size_k, size_n, quant_type.size_bits))
|
||||
opcheck(
|
||||
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.awq_marlin_repack(
|
||||
@@ -180,23 +209,34 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
|
||||
@pytest.mark.parametrize(
|
||||
"group_size",
|
||||
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
|
||||
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
mnk_factors, act_order, is_k_full, use_atomic_add,
|
||||
use_fp32_reduce, dtype):
|
||||
def test_gptq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
dtype,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
@@ -225,11 +265,13 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
return
|
||||
|
||||
if group_size == 16:
|
||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
|
||||
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
|
||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
|
||||
b_weight.T, group_size
|
||||
)
|
||||
else:
|
||||
w_ref, marlin_q_w, marlin_s = \
|
||||
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
|
||||
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
|
||||
b_weight.T, group_size
|
||||
)
|
||||
marlin_s2 = None
|
||||
|
||||
g_idx = None
|
||||
@@ -240,8 +282,7 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
return
|
||||
if act_order:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
|
||||
b_weight.T, group_size)
|
||||
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_zp = None
|
||||
@@ -250,7 +291,8 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
if group_size == 16:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b_weight, quant_type, group_size)
|
||||
b_weight, quant_type, group_size
|
||||
)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_s2 = None
|
||||
@@ -258,18 +300,37 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
if group_size == 16:
|
||||
return
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
b_weight, quant_type, group_size, act_order
|
||||
)
|
||||
marlin_zp = None
|
||||
marlin_s2 = None
|
||||
|
||||
workspace = marlin_make_workspace_new(w_ref.device)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
|
||||
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
|
||||
use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_gemm,
|
||||
(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
None,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
quant_type.id,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
False,
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
@@ -302,23 +363,40 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
|
||||
# TODO: find better way to test this?
|
||||
@torch.compile(fullgraph=True)
|
||||
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s, scratch, quant_type, size_m, size_n,
|
||||
size_k):
|
||||
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s, scratch, quant_type, size_m,
|
||||
size_n, size_k)
|
||||
def marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
scratch,
|
||||
quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
):
|
||||
return ops.gptq_marlin_24_gemm(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
scratch,
|
||||
quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
mnk_factors):
|
||||
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
@@ -328,19 +406,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
|
||||
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
|
||||
b_weight, quant_type, group_size
|
||||
)
|
||||
|
||||
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||
workspace_24 = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a_input, w_24_ref)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_24_gemm,
|
||||
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
|
||||
workspace_24.scratch, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1]),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_24_gemm,
|
||||
(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
workspace_24.scratch,
|
||||
quant_type.id,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
output = marlin_24_gemm_tester(
|
||||
a_input,
|
||||
@@ -361,8 +451,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
|
||||
@@ -386,22 +478,22 @@ def test_hqq_marlin_gemm(
|
||||
a_input = rand_data((size_m, size_k))
|
||||
dev = a_input.device
|
||||
|
||||
b_weight = torch.randint(0,
|
||||
10, (size_n, size_k),
|
||||
dtype=torch.uint8,
|
||||
device=dev)
|
||||
b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
|
||||
scale = rand_data((size_n, size_k // group_size))
|
||||
zero = rand_data((size_n, size_k // group_size))
|
||||
|
||||
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
|
||||
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
|
||||
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
|
||||
4).to(dev)
|
||||
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
|
||||
group_size).to(dev)
|
||||
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
|
||||
group_size).to(dev)
|
||||
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
|
||||
dev
|
||||
)
|
||||
marlin_s = marlin_permute_scales(
|
||||
scale.transpose(1, 0), size_k, size_n, group_size
|
||||
).to(dev)
|
||||
marlin_zp = marlin_permute_scales(
|
||||
zero.transpose(1, 0), size_k, size_n, group_size
|
||||
).to(dev)
|
||||
|
||||
g_idx = marlin_make_empty_g_idx(dev)
|
||||
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
@@ -433,8 +525,7 @@ def test_hqq_marlin_gemm(
|
||||
s_flat = scale.reshape(-1, 1)
|
||||
dequant = (b_flat - zp_flat) * s_flat
|
||||
|
||||
output_ref = torch.matmul(a_input,
|
||||
dequant.reshape(b_weight.shape).transpose(1, 0))
|
||||
output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@@ -451,11 +542,12 @@ def test_marlin_gemm_subset_input():
|
||||
big_m = size_m * 2
|
||||
big_k = size_k * 2
|
||||
|
||||
a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
|
||||
a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, False)
|
||||
b_weight, quant_type, group_size, False
|
||||
)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
@@ -497,12 +589,13 @@ def test_marlin_gemm_with_bias(size_m):
|
||||
size_k, size_n = 1024, 2048
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
b_bias = rand_data((size_n, )) * 10
|
||||
b_bias = rand_data((size_n,)) * 10
|
||||
|
||||
marlin_bias = marlin_permute_bias(b_bias)
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, False)
|
||||
b_weight, quant_type, group_size, False
|
||||
)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
Reference in New Issue
Block a user