[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
@@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
@@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]

MOE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
    # GPTQ-INT4
    {
        "b_type": scalar_types.uint4b8,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT8
    {
        "b_type": scalar_types.uint8b128,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # FP8
    {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
    # NVFP4
    {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
    # MXFP4
    {
        "a_type": [scalar_types.bfloat16],
        "b_type": scalar_types.float4_e2m1f,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.float4_e2m1f,
        "c_type": [scalar_types.bfloat16],
        "group_blocks": [2],
    },
]

FUSED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (1, 2048, 128),
@@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases():
    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    e_list = [5, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.bfloat16]
    group_size_list = [-1, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        MOE_MARLIN_QUANT_TEST_CONFIGS,
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        ep_size_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def is_invalid(
        m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
        a_type,
        b_type,
        c_type,
        group_blocks,
        m,
        n,
        k,
        e,
        topk,
        ep_size,
        act_order,
        is_k_full,
    ):
        if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
            return False
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        if group_size > 0 and k % group_size != 0:
            return False

        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
        if act_order and group_size in [-1, k, n]:
            return False
        if group_size in [k, n]:
            return False
        if not act_order and is_k_full:
            return False

        return True
        return a_type.size_bits < 16 or a_type is c_type

    cases = []
    for case in all_combinations:
        if is_invalid(*case):
            cases.append(case)
        quant_test_config, m, n, k, _, _, _, act_order, *_ = case
        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        f16_types = [scalar_types.float16]
        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            if (
                sub_case[0] == scalar_types.float8_e4m3fn
                and current_platform.get_device_capability() not in [89, 120]
            ):
                continue
            args = sub_case + (m, n, k) + case[4:]
            if is_invalid(*args):
                cases.append(args)
    return cases

@@ -571,6 +640,7 @@ class MarlinMoEWeightData:
    qweight: torch.Tensor
    scales: torch.Tensor
    global_scale: torch.Tensor | None
    a_scales_factor: torch.Tensor | None
    g_idx: torch.Tensor | None
    zeros: torch.Tensor | None
    sort_indices: torch.Tensor | None
@@ -583,11 +653,20 @@ class MarlinMoEWeightData:
        group_size: int,
        act_order: bool | None = None,
        bias: torch.Tensor | None = None,
        input_type: ScalarType = None,
    ) -> "MarlinMoEWeightData":
        assert w.ndim == 3

        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
        k = w.shape[-1]

        if input_type == scalar_types.int8:
            input_dtype = torch.int8
        elif input_type == scalar_types.float8_e4m3fn:
            input_dtype = torch.float8_e4m3fn
        else:
            input_dtype = w.dtype

        w_ref_l: list[torch.Tensor] = []
        qweight_l: list[torch.Tensor] = []
        scales_l: list[torch.Tensor] = []
@@ -601,11 +680,13 @@ class MarlinMoEWeightData:
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                        rand_marlin_weight_nvfp4_like(
                            w[i], group_size, input_dtype=input_dtype
                        )
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                        w[i], group_size, input_dtype=input_dtype
                    )
                    global_scale = None

@@ -615,13 +696,18 @@ class MarlinMoEWeightData:
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
                w_ref, qweight, scales = marlin_quant_fp8_torch(
                    w[i], group_size, input_dtype=input_dtype
                )
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                    w[i].transpose(1, 0),
                    quant_type,
                    group_size,
                    input_dtype=input_dtype,
                )

                w_ref_l.append(w_ref.T)
@@ -631,7 +717,12 @@ class MarlinMoEWeightData:
            else:
                test_perm = torch.randperm(k)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                    w[i].transpose(1, 0),
                    quant_type,
                    group_size,
                    act_order,
                    test_perm,
                    input_dtype=input_dtype,
                )

                w_ref_l.append(w_ref.T)
@@ -652,11 +743,18 @@ class MarlinMoEWeightData:
        sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
        marlin_bias = stack_and_dev(bias_l) if bias_l else None

        a_scales_factor = None
        if input_type == scalar_types.int8 and group_size != -1:
            a_scales_factor = 1 / 4096 * scales.max().float()
            scales = scales / scales.max() * 4096
            scales = scales.round().to(torch.int16).view(w.dtype)

        return MarlinMoEWeightData(
            w_ref=w_ref,
            qweight=qweight,
            scales=scales,
            global_scale=global_scale,
            a_scales_factor=a_scales_factor,
            g_idx=g_idx,
            zeros=zeros,
            sort_indices=sort_indices,
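# Sketch (illustration only; the helper name and standalone form are invented here,
# assuming a 16-bit weight dtype such as torch.float16/bfloat16): how the int8-activation
# branch above renormalizes the per-group weight scales into an integer-friendly range
# and stores them as int16 bit patterns, while a single float a_scales_factor compensates
# on the activation side so that factor * repacked_scales ~= original scales.
import torch

def repack_scales_for_int8_act(scales: torch.Tensor, weight_dtype: torch.dtype):
    # One global factor per weight tensor, mirroring `1 / 4096 * scales.max().float()`.
    factor = scales.max().float() / 4096
    repacked = (scales / scales.max() * 4096).round().to(torch.int16).view(weight_dtype)
    return repacked, factor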
@@ -666,28 +764,47 @@ class MarlinMoEWeightData:

@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
    (
        "a_type, b_type, c_type, group_blocks,"
        "m, n, k, e, topk, ep_size, act_order, is_k_full"
    ),
    marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
    a_type,
    b_type,
    c_type,
    group_blocks,
    m,
    n,
    k,
    e,
    topk,
    ep_size,
    act_order,
    is_k_full,
):
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed(1)
    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    if ep_size > 1:
        local_e = e // ep_size
@@ -700,11 +817,19 @@ def test_fused_marlin_moe(
        e_map = None

    w1_data = MarlinMoEWeightData.make(
        w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
        w=w1,
        quant_type=b_type,
        group_size=group_size,
        act_order=act_order,
        input_type=a_type,
    )

    w2_data = MarlinMoEWeightData.make(
        w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
        w=w2,
        quant_type=b_type,
        group_size=group_size,
        act_order=act_order,
        input_type=a_type,
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)
@@ -712,8 +837,18 @@ def test_fused_marlin_moe(
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(
            a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)
        torch_output = torch_experts(
            a,
            w1_data.w_ref,
            w2_data.w_ref,
            topk_weight=topk_weight,
            topk_ids=topk_ids,
            global_num_experts=e,
            expert_map=e_map,
            quant_dtype=a_dtype,
            per_act_token_quant=True,
        )

    marlin_output = fused_marlin_moe(
@@ -733,15 +868,18 @@ def test_fused_marlin_moe(
        global_scale2=w2_data.global_scale,
        g_idx1=w1_data.g_idx,
        g_idx2=w2_data.g_idx,
        input_global_scale1=w1_data.a_scales_factor,
        input_global_scale2=w2_data.a_scales_factor,
        sort_indices1=w1_data.sort_indices,
        sort_indices2=w2_data.sort_indices,
        w1_zeros=w1_data.zeros,
        w2_zeros=w2_data.zeros,
        quant_type_id=quant_type.id,
        input_dtype=a_dtype,
        quant_type_id=b_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
    torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)


@pytest.mark.flaky(reruns=2)

@@ -5,6 +5,8 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""

import itertools

import pytest
import torch

@@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES,
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
@@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
@@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights,
    sort_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
@@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
    (13, 11),
]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
@@ -74,6 +84,64 @@ MNK_FACTORS = [

DTYPES = [torch.float16, torch.bfloat16]

DENSE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
    # GPTQ-INT4
    {
        "b_type": scalar_types.uint4b8,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT8
    {
        "b_type": scalar_types.uint8b128,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # FP8
    {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
    # NVFP4
    {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
    # MXFP4
    {
        "a_type": [scalar_types.bfloat16],
        "b_type": scalar_types.float4_e2m1f,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.float4_e2m1f,
        "c_type": [scalar_types.bfloat16],
        "group_blocks": [2],
    },
]


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
@@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16):
    return torch.randn(shape, dtype=dtype, device="cuda")

@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    qweight_unpacked = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
    qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)

    cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)

    torch_res = torch.where(
        qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
    )
    torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
    torch_res = torch_res.to(torch.int8).view(torch.int32)

    assert (cuda_res == torch_res).all()


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    group_size = 128

    qweight_unpacked = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    qzeros_unpacked = torch.randint(
        0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
    )

    qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
    qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
    qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
    qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)

    cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)

    repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
    torch_res = qweight_unpacked - repeated_zp
    torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
    torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
    torch_res = torch_res.to(torch.int8).view(torch.int32)

    assert (cuda_res == torch_res).all()

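# Sketch (illustration only; the function name is invented here) of the int4 preprocessing
# reference that the two tests above check for the zero-point-free case: each uint4 value q
# is remapped with torch.where(q >= 8, q - 8, 15 - q) and nibble pairs are repacked into
# int8 bytes viewed as int32, mirroring the torch_res computation above.
import torch

def int4_fp8_preprocess_reference(q_unpacked: torch.Tensor) -> torch.Tensor:
    # q_unpacked: 2-D int32 tensor holding one uint4 value (0..15) per element,
    # with an even number of columns so nibbles can be paired.
    mapped = torch.where(q_unpacked >= 8, q_unpacked - 8, 15 - q_unpacked)
    packed = mapped[:, ::2] * 16 + mapped[:, 1::2]
    return packed.to(torch.int8).view(torch.int32)

# Example: int4_fp8_preprocess_reference(torch.randint(0, 16, (4, 8), dtype=torch.int32))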
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
@@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
    k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
    m_factor, n_factor, k_factor = mnk_factors
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = 128

    # Filter act_order
    if act_order:
@@ -109,6 +230,8 @@ def test_gptq_marlin_repack(
            return
        if group_size == size_k:
            return
    if is_a_8bit:
        return

    # Normalize group_size
    if group_size == -1:
@@ -133,23 +256,19 @@ def test_gptq_marlin_repack(
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

@@ -163,18 +282,15 @@ def test_gptq_marlin_repack(
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    group_size = 128

    # Create input
    b_weight = rand_data((size_k, size_n))
@@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
        torch.ops._C.awq_marlin_repack,
        (q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
        q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)

def marlin_generate_valid_test_cases():
    all_combinations = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    def is_invalid(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        if use_atomic_add:
            if use_fp32_reduce:
                return False
            if (
                c_type == scalar_types.bfloat16
                and torch.cuda.get_device_capability()[0] < 9
            ):
                return False

        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        if group_size > 0 and size_k % group_size != 0:
            return False

        if act_order and group_size in [-1, size_k]:
            return False
        if group_size == size_k:
            return False
        if not act_order and is_k_full:
            return False

        return a_type.size_bits < 16 or a_type is c_type

    cases = []
    for case in all_combinations:
        quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
        size_m = mnk_factors[0]
        size_n = mnk_factors[1] * n_chunk
        size_k = mnk_factors[2] * k_chunk

        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        f16_types = [scalar_types.float16, scalar_types.bfloat16]
        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            if (
                sub_case[0] == scalar_types.float8_e4m3fn
                and current_platform.get_device_capability() not in [89, 120]
            ):
                continue
            args = sub_case + (size_m, size_n, size_k) + case[4:]
            if is_invalid(*args):
                cases.append(args)
    return cases

@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
    "group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
    (
        "a_type, b_type, c_type, group_blocks,"
        "size_m, size_n, size_k, act_order, is_k_full,"
        "use_atomic_add, use_fp32_reduce"
    ),
    marlin_generate_valid_test_cases(),
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    a_type,
    b_type,
    c_type,
    group_blocks,
    size_m,
    size_n,
    size_k,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
    dtype,
):
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
    has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return
    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if size_k % group_size != 0:
        return
    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype)
    b_weight = rand_data((size_k, size_n), dtype)

    if quant_type == scalar_types.float4_e2m1f:
        if group_size not in [16, 32] or act_order:
            return
        if group_size == 32 and dtype == torch.float16:
            return
    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    if b_type == scalar_types.float4_e2m1f:
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size
                b_weight.T, group_size, input_dtype=a_dtype
            )
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif quant_type == scalar_types.float8_e4m3fn:
        if group_size not in [-1, 128]:
            return
        if act_order:
            return
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )

        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    opcheck(
        torch.ops._C.gptq_marlin_gemm,
        (
            a_input,
            None,
            marlin_q_w,
            None,
            marlin_s,
            marlin_s2,
            marlin_zp,
            g_idx,
            sort_indices,
            workspace,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
            is_k_full,
            use_atomic_add,
            use_fp32_reduce,
            False,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    if a_type == scalar_types.int8:
        a_input, a_scales = per_token_quant_int8(a_input)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)

        if group_size != -1:
            a_scales = a_scales / 4096 * marlin_s.max()
            a_scales = a_scales.float()
            marlin_s = marlin_s / marlin_s.max() * 4096
            marlin_s = marlin_s.round().to(torch.int16).view(dtype)
    elif a_type == scalar_types.float8_e4m3fn:
        a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)
    else:
        assert a_type.size_bits == 16
        a_input_ref = a_input
        a_scales = None

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
@@ -352,12 +527,9 @@ def test_gptq_marlin_gemm(
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()
    output_ref = torch.matmul(a_input_ref, w_ref)

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04

@@ -507,6 +679,7 @@ def test_hqq_marlin_gemm(
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
@@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input():
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
@@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m):
        marlin_bias,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,

@@ -846,6 +846,13 @@ def torch_experts(
        or (expert_map is not None and global_num_experts == expert_map.shape[0])
    )

    if quant_dtype in [torch.float16, torch.bfloat16]:
        quant_dtype = None
    quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
    if quant_input_only:
        assert a1_scale is None and a2_scale is None
        assert per_act_token_quant

    M, K = a.shape
    topk = topk_ids.shape[1]

@@ -863,6 +870,9 @@ def torch_experts(
            a, a1_scale, quant_dtype, per_act_token_quant, block_shape
        )

        if quant_input_only:
            a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)

    num_experts = w1.shape[0]

    topk_ids = topk_ids.view(-1)
@@ -882,6 +892,14 @@ def torch_experts(
            out[mask] = tmp2 @ w2[i].transpose(0, 1)
            if b_bias2 is not None:
                out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
        elif quant_input_only:
            tmp1 = a[mask] @ w1[i].transpose(0, 1)
            tmp2 = SiluAndMul()(tmp1)
            tmp2, tmp2_scale = moe_kernel_quantize_input(
                tmp2, None, quant_dtype, per_act_token_quant
            )
            tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
            out[mask] = tmp2 @ w2[i].transpose(0, 1)
        elif block_shape is not None:
            # block quantized
            assert (
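# Sketch (illustration only; helper name and the symmetric-int8 scheme are assumptions,
# not the library's API): the per-token quantize/dequantize round trip that the
# quant_input_only reference path above applies to activations, so the torch reference
# sees roughly the same activation rounding error as the w4a8 kernel under test.
import torch

def per_token_int8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # One scale per row (token): map the row's absolute max to 127, quantize, dequantize.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return (q.float() * scale).to(x.dtype)

# e.g. per_token_int8_roundtrip(torch.randn(4, 16)) approximates its input.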