[Quantization][Deprecation] Remove Marlin 24 (#32688)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -10,15 +10,9 @@ import itertools
 import pytest
 import torch

-from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
+from tests.kernels.utils import opcheck
 from tests.quantization.utils import is_quant_method_supported
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-    GPTQ_MARLIN_24_MAX_PARALLEL,
-    GPTQ_MARLIN_24_MIN_THREAD_N,
-    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
-    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
-)
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_quant_int8,
 )
@@ -36,15 +30,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     marlin_quant_fp8_torch,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    MarlinWorkspace,
     awq_marlin_quantize,
     get_weight_perm,
     marlin_quantize,
     marlin_weights,
 )
-from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
-    marlin_24_quantize,
-)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     awq_pack,
     gptq_pack,
@@ -57,9 +47,7 @@ from vllm.scalar_type import scalar_types

 if current_platform.is_rocm():
     pytest.skip(
-        "These tests require gptq_marlin_repack,"
-        "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
-        "or marlin_gemm which are not supported on ROCm.",
+        "These tests require marlin, which is not supported on ROCm.",
         allow_module_level=True,
     )

@@ -71,9 +59,6 @@ USE_FP32_REDUCE_OPTS = [True]
 MARLIN_K_CHUNKS = [128]
 MARLIN_N_CHUNKS = [64, 256]

-MARLIN_24_K_CHUNKS = [128]
-MARLIN_24_N_CHUNKS = [512]
-
 MARLIN_REPACK_NK_FACTORS = [
     (4, 8),
     (7, 5),
@@ -538,96 +523,6 @@ def test_marlin_gemm(
     assert max_diff < 0.04


-# TODO: find better way to test this?
-@torch.compile(fullgraph=True)
-def marlin_24_gemm_tester(
-    a_input,
-    marlin_24_q_w_comp,
-    marlin_24_meta,
-    marlin_24_s,
-    scratch,
-    quant_type,
-    size_m,
-    size_n,
-    size_k,
-):
-    return ops.gptq_marlin_24_gemm(
-        a_input,
-        marlin_24_q_w_comp,
-        marlin_24_meta,
-        marlin_24_s,
-        scratch,
-        quant_type,
-        size_m,
-        size_n,
-        size_k,
-    )
-
-
-@pytest.mark.skipif(
-    not is_quant_method_supported("gptq_marlin"),
-    reason="Marlin is not supported on this GPU type.",
-)
-@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
-@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
-@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
-@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
-def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
-    m_factor, n_factor, k_factor = mnk_factors
-
-    size_m = m_factor
-    size_k = k_chunk * k_factor
-    size_n = n_chunk * n_factor
-
-    a_input = rand_data((size_m, size_k))
-    b_weight = rand_data((size_k, size_n))
-
-    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
-        b_weight, quant_type, group_size
-    )
-
-    workspace_24 = MarlinWorkspace(
-        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
-    )
-
-    output_ref = torch.matmul(a_input, w_24_ref)
-
-    opcheck(
-        torch.ops._C.gptq_marlin_24_gemm,
-        (
-            a_input,
-            marlin_24_q_w_comp,
-            marlin_24_meta,
-            marlin_24_s,
-            workspace_24.scratch,
-            quant_type.id,
-            a_input.shape[0],
-            b_weight.shape[1],
-            a_input.shape[1],
-        ),
-        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
-    )
-
-    output = marlin_24_gemm_tester(
-        a_input,
-        marlin_24_q_w_comp,
-        marlin_24_meta,
-        marlin_24_s,
-        workspace_24.scratch,
-        quant_type,
-        a_input.shape[0],
-        b_weight.shape[1],
-        a_input.shape[1],
-    )
-
-    torch.cuda.synchronize()
-
-    max_diff = compute_max_diff(output, output_ref)
-
-    assert max_diff < 0.04
-
-
 def test_marlin_gemm_subset_input():
     quant_type = scalar_types.uint4b8
     group_size = 128