[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
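Every hunk below follows one pattern: the Marlin CUDA kernels are exposed as ops registered under `torch.ops._C`, and each test now calls `opcheck` on the registered op to validate its schema and fake-tensor registration, which is what lets `torch.compile` trace through the op instead of graph-breaking. A minimal self-contained sketch of the same idea, assuming PyTorch >= 2.4; the `demo::scaled_add` op and all names in it are hypothetical, not vLLM's:

    import torch
    from torch.library import opcheck

    # Register a toy custom op, analogous to how vLLM registers its _C kernels.
    @torch.library.custom_op("demo::scaled_add", mutates_args=())
    def scaled_add(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
        return a + scale * b

    # A fake (meta) implementation lets the compiler trace the op symbolically
    # instead of falling back to eager at the call site (a "graph break").
    @scaled_add.register_fake
    def _(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
        return torch.empty_like(a)

    a, b = torch.randn(4), torch.randn(4)

    # opcheck exercises schema / fake-kernel / autograd consistency, the same
    # kind of check the tests below run against the torch.ops._C.* kernels.
    opcheck(scaled_add, (a, b, 2.0))

    # With the fake kernel registered, fullgraph=True compiles without a break.
    compiled = torch.compile(lambda x, y: scaled_add(x, y, 2.0), fullgraph=True)
    torch.testing.assert_close(compiled(a, b), a + 2.0 * b)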
@@ -5,6 +5,7 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
 import pytest
 import torch
 
+from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from tests.quantization.utils import is_quant_method_supported
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
@@ -73,12 +74,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                             act_order, mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors
 
-    size_m = m_factor
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-
     # Filter act_order
     if act_order:
         if group_size == -1:
@@ -112,6 +110,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
     marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                   weight_perm)
 
+    opcheck(torch.ops._C.gptq_marlin_repack,
+            (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
+
     # Run Marlin repack GPU kernel
     marlin_q_w_2 = ops.gptq_marlin_repack(
         q_w_gptq,
@@ -137,12 +138,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors
 
-    size_m = m_factor
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-
     # Normalize group_size
     if group_size == -1:
         group_size = size_k
@@ -165,6 +163,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
     marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                   weight_perm)
 
+    opcheck(torch.ops._C.awq_marlin_repack,
+            (q_w_awq, size_k, size_n, quant_type.size_bits))
+
     # Run Marlin repack GPU kernel
     marlin_q_w_2 = ops.awq_marlin_repack(
         q_w_awq,
@@ -204,9 +205,6 @@ def test_gptq_marlin_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     if act_order:
         if group_size == -1:
             return
@@ -224,6 +222,13 @@ def test_gptq_marlin_gemm(
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)
 
+    opcheck(
+        torch.ops._C.gptq_marlin_gemm,
+        (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
+         workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
+         a_input.shape[1], is_k_full, False, use_fp32_reduce),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+
     output = ops.gptq_marlin_gemm(
         a_input,
         marlin_q_w,
@@ -245,7 +250,6 @@ def test_gptq_marlin_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -265,9 +269,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))
 
@@ -279,6 +280,12 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
 
     output_ref = torch.matmul(a_input, w_24_ref)
 
+    opcheck(torch.ops._C.gptq_marlin_24_gemm,
+            (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
+             workspace_24.scratch, quant_type, a_input.shape[0],
+             b_weight.shape[1], a_input.shape[1]),
+            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+
     output = ops.gptq_marlin_24_gemm(
         a_input,
         marlin_24_q_w_comp,
@@ -294,7 +301,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -321,9 +327,6 @@ def test_fp8_marlin_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k), dtype=dtype)
     b_weight = rand_data((size_k, size_n), dtype=dtype)
 
@@ -353,6 +356,10 @@ def test_fp8_marlin_gemm(
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)
 
+    opcheck(torch.ops._C.fp8_marlin_gemm,
+            (a_input, marlin_qweight, marlin_scales, workspace.scratch,
+             num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
+
     output = ops.fp8_marlin_gemm(
         a=a_input,
         b_q_weight=marlin_qweight,
@@ -368,7 +375,6 @@ def test_fp8_marlin_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -396,9 +402,6 @@ def test_awq_marlin_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))
 
@@ -434,7 +437,6 @@ def test_awq_marlin_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -460,9 +462,6 @@ def test_marlin_qqq_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))
 
@@ -479,6 +478,11 @@ def test_marlin_qqq_gemm(
     workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                 MARLIN_QQQ_MAX_PARALLEL)
 
+    opcheck(torch.ops._C.marlin_qqq_gemm,
+            (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
+             marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
+             b_weight.shape[1], a_input.shape[1]))
+
     output = ops.marlin_qqq_gemm(
         q_a,
         marlin_qqq_q_w,
@@ -495,6 +499,5 @@ def test_marlin_qqq_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
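To exercise the new opcheck paths locally, the command from the module docstring still applies; standard pytest `-k` filtering can narrow it to one kernel family (the `gptq` filter here is just an example):

    pytest tests/kernels/marlin/test_marlin_gemm.py -k gptq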