[Quantization][Deprecation] Remove Marlin 24 (#32688)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Robert Shaw
2026-01-28 07:54:59 -08:00
committed by GitHub
parent 8e5e40daf4
commit af9b69f977
20 changed files with 159 additions and 3161 deletions
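Marlin 24 was the Marlin GEMM path for GPTQ weights pruned to 2:4 semi-structured sparsity: at most two non-zero values in every contiguous group of four along a row. For reference on what the removed kernel consumed, here is a minimal sketch of the 2:4 pattern check; `is_2_4_sparse` is a hypothetical helper for illustration, not vLLM code:

```python
# Hypothetical helper, not vLLM code: verifies the 2:4 semi-structured
# sparsity pattern the removed kernel targeted -- at most two non-zero
# values in every contiguous group of four along the last dimension.
import torch


def is_2_4_sparse(w: torch.Tensor) -> bool:
    assert w.shape[-1] % 4 == 0, "last dim must be divisible by 4"
    groups = w.reshape(*w.shape[:-1], -1, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())


dense = torch.randn(16, 64)
pruned = dense.clone()
# Naive 2:4 pruning: zero the two smallest-magnitude entries per group of 4.
g = pruned.view(16, -1, 4)  # view shares storage, so scatter_ edits `pruned`
smallest = g.abs().argsort(dim=-1)[..., :2]
g.scatter_(-1, smallest, 0.0)

assert is_2_4_sparse(pruned)
assert not is_2_4_sparse(dense)
```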

@@ -10,15 +10,9 @@ import itertools
import pytest
import torch
-from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
+from tests.kernels.utils import opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-    GPTQ_MARLIN_24_MAX_PARALLEL,
-    GPTQ_MARLIN_24_MIN_THREAD_N,
-    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
-    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
-)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
@@ -36,15 +30,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
-from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
-    marlin_24_quantize,
-)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack,
    gptq_pack,
@@ -57,9 +47,7 @@ from vllm.scalar_type import scalar_types
if current_platform.is_rocm():
    pytest.skip(
-        "These tests require gptq_marlin_repack,"
-        "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
-        "or marlin_gemm which are not supported on ROCm.",
+        "These tests require marlin, which is not supported on ROCm.",
        allow_module_level=True,
    )
@@ -71,9 +59,6 @@ USE_FP32_REDUCE_OPTS = [True]
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]
-MARLIN_24_K_CHUNKS = [128]
-MARLIN_24_N_CHUNKS = [512]
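The removed `MARLIN_24_*_CHUNKS` constants fed the same shape scheme the surviving tests still use: each GEMM problem size is a base chunk multiplied by a small factor, expanded with `itertools.product`. A sketch with illustrative values (the factor tuples here are hypothetical, not the file's actual `MNK_FACTORS`):

```python
# Illustrative values only: how chunk sizes and factor tuples expand into
# the (size_m, size_n, size_k) problem shapes that parametrize the tests.
import itertools

K_CHUNKS = [128]                          # base tile along K
N_CHUNKS = [64, 256]                      # base tiles along N
MNK_FACTORS = [(1, 1, 1), (13, 17, 67)]   # hypothetical multipliers

shapes = [
    (m, n_chunk * n, k_chunk * k)
    for k_chunk, n_chunk, (m, n, k) in itertools.product(
        K_CHUNKS, N_CHUNKS, MNK_FACTORS
    )
]
print(shapes)  # each tuple becomes one parametrized test case
```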
MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
@@ -538,96 +523,6 @@ def test_marlin_gemm(
    assert max_diff < 0.04

-# TODO: find better way to test this?
-@torch.compile(fullgraph=True)
-def marlin_24_gemm_tester(
-    a_input,
-    marlin_24_q_w_comp,
-    marlin_24_meta,
-    marlin_24_s,
-    scratch,
-    quant_type,
-    size_m,
-    size_n,
-    size_k,
-):
-    return ops.gptq_marlin_24_gemm(
-        a_input,
-        marlin_24_q_w_comp,
-        marlin_24_meta,
-        marlin_24_s,
-        scratch,
-        quant_type,
-        size_m,
-        size_n,
-        size_k,
-    )
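The removed `marlin_24_gemm_tester` existed only to run the custom op under `@torch.compile(fullgraph=True)`, where any Dynamo graph break raises instead of silently falling back to eager; hence the "find better way to test this?" TODO. The same pattern on a stock op, as a minimal sketch:

```python
# Sketch of the removed tester's pattern: fullgraph=True turns any Dynamo
# graph break into an error, so compiling the call doubles as a regression
# test that the op traces cleanly. Shown on a stock op for portability.
import torch


@torch.compile(fullgraph=True)
def compiled_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.matmul(a, b)


a, b = torch.randn(8, 16), torch.randn(16, 4)
torch.testing.assert_close(compiled_matmul(a, b), a @ b)
```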
-@pytest.mark.skipif(
-    not is_quant_method_supported("gptq_marlin"),
-    reason="Marlin is not supported on this GPU type.",
-)
-@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
-@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
-@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
-@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
-@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
-def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
-    m_factor, n_factor, k_factor = mnk_factors
-
-    size_m = m_factor
-    size_k = k_chunk * k_factor
-    size_n = n_chunk * n_factor
-
-    a_input = rand_data((size_m, size_k))
-    b_weight = rand_data((size_k, size_n))
-
-    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
-        b_weight, quant_type, group_size
-    )
-
-    workspace_24 = MarlinWorkspace(
-        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
-    )
-
-    output_ref = torch.matmul(a_input, w_24_ref)
-
-    opcheck(
-        torch.ops._C.gptq_marlin_24_gemm,
-        (
-            a_input,
-            marlin_24_q_w_comp,
-            marlin_24_meta,
-            marlin_24_s,
-            workspace_24.scratch,
-            quant_type.id,
-            a_input.shape[0],
-            b_weight.shape[1],
-            a_input.shape[1],
-        ),
-        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
-    )
-
-    output = marlin_24_gemm_tester(
-        a_input,
-        marlin_24_q_w_comp,
-        marlin_24_meta,
-        marlin_24_s,
-        workspace_24.scratch,
-        quant_type,
-        a_input.shape[0],
-        b_weight.shape[1],
-        a_input.shape[1],
-    )
-
-    torch.cuda.synchronize()
-
-    max_diff = compute_max_diff(output, output_ref)
-
-    assert max_diff < 0.04
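`compute_max_diff` is defined outside this diff; one plausible mean-relative-error definition consistent with the `< 0.04` bound (an assumption, not necessarily vLLM's exact implementation):

```python
# Assumption: compute_max_diff lives elsewhere in the test utilities; this
# is a plausible mean-relative-error metric, not necessarily vLLM's exact one.
import torch


def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> float:
    return float(
        torch.mean(torch.abs(output - output_ref))
        / torch.mean(torch.abs(output_ref))
    )


ref = torch.randn(64, 64)
noisy = ref + 0.01 * torch.randn_like(ref)
assert compute_max_diff(noisy, ref) < 0.04  # small perturbation passes the bound
```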
def test_marlin_gemm_subset_input():
    quant_type = scalar_types.uint4b8
    group_size = 128