[Kernel][Quantization] Add W4A8 support for the Marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
@@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
|
||||
MOE_MARLIN_QUANT_TEST_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT8
|
||||
{
|
||||
"b_type": scalar_types.uint8b128,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
|
||||
# NVFP4
|
||||
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": [scalar_types.bfloat16],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"c_type": [scalar_types.bfloat16],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
FUSED_MOE_MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 2048, 128),
|
||||
@@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases():
|
||||
m_list = [1, 123, 666]
|
||||
n_list = [128, 1024]
|
||||
k_list = [256, 2048]
|
||||
e_list = [4, 12]
|
||||
e_list = [5, 12]
|
||||
topk_list = [2, 3]
|
||||
ep_size_list = [1, 4]
|
||||
dtype_list = [torch.bfloat16]
|
||||
group_size_list = [-1, 32, 128]
|
||||
act_order_list = [True, False]
|
||||
quant_type_list = [
|
||||
scalar_types.float4_e2m1f,
|
||||
scalar_types.float8_e4m3fn,
|
||||
scalar_types.uint4,
|
||||
scalar_types.uint4b8,
|
||||
scalar_types.uint8b128,
|
||||
]
|
||||
is_k_full_list = [True, False]
|
||||
|
||||
all_combinations = itertools.product(
|
||||
MOE_MARLIN_QUANT_TEST_CONFIGS,
|
||||
m_list,
|
||||
n_list,
|
||||
k_list,
|
||||
e_list,
|
||||
topk_list,
|
||||
ep_size_list,
|
||||
dtype_list,
|
||||
group_size_list,
|
||||
act_order_list,
|
||||
quant_type_list,
|
||||
is_k_full_list,
|
||||
)
|
||||
|
||||
def is_invalid(
|
||||
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
ep_size,
|
||||
act_order,
|
||||
is_k_full,
|
||||
):
|
||||
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
|
||||
return False
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size not in [16, 32]:
|
||||
return False
|
||||
if dtype == torch.float16 and group_size == 32:
|
||||
return False
|
||||
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
if group_size > 0 and k % group_size != 0:
|
||||
return False
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size in (-1, k, n):
|
||||
return False
|
||||
if quant_type not in [scalar_types.uint4b8]:
|
||||
return False
|
||||
elif not is_k_full:
|
||||
if act_order and group_size in [-1, k, n]:
|
||||
return False
|
||||
if group_size in [k, n]:
|
||||
return False
|
||||
if not act_order and is_k_full:
|
||||
return False
|
||||
|
||||
return True
|
||||
return a_type.size_bits < 16 or a_type is c_type
|
||||
|
||||
cases = []
|
||||
for case in all_combinations:
|
||||
if is_invalid(*case):
|
||||
cases.append(case)
|
||||
quant_test_config, m, n, k, _, _, _, act_order, *_ = case
|
||||
if act_order and not quant_test_config.get("support_act_order", False):
|
||||
continue
|
||||
|
||||
f16_types = [scalar_types.float16]
|
||||
inner_combinations = itertools.product(
|
||||
quant_test_config.get("a_type", f16_types),
|
||||
[quant_test_config["b_type"]],
|
||||
quant_test_config.get("c_type", f16_types),
|
||||
quant_test_config["group_blocks"],
|
||||
)
|
||||
|
||||
for sub_case in inner_combinations:
|
||||
if (
|
||||
sub_case[0] == scalar_types.float8_e4m3fn
|
||||
and current_platform.get_device_capability() not in [89, 120]
|
||||
):
|
||||
continue
|
||||
args = sub_case + (m, n, k) + case[4:]
|
||||
if is_invalid(*args):
|
||||
cases.append(args)
|
||||
return cases
|
||||
|
||||
|
||||
@@ -571,6 +640,7 @@ class MarlinMoEWeightData:
|
||||
qweight: torch.Tensor
|
||||
scales: torch.Tensor
|
||||
global_scale: torch.Tensor | None
|
||||
a_scales_factor: torch.Tensor | None
|
||||
g_idx: torch.Tensor | None
|
||||
zeros: torch.Tensor | None
|
||||
sort_indices: torch.Tensor | None
|
||||
@@ -583,11 +653,20 @@ class MarlinMoEWeightData:
|
||||
group_size: int,
|
||||
act_order: bool | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
input_type: ScalarType = None,
|
||||
) -> "MarlinMoEWeightData":
|
||||
assert w.ndim == 3
|
||||
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
k = w.shape[-1]
|
||||
|
||||
if input_type == scalar_types.int8:
|
||||
input_dtype = torch.int8
|
||||
elif input_type == scalar_types.float8_e4m3fn:
|
||||
input_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
input_dtype = w.dtype
|
||||
|
||||
w_ref_l: list[torch.Tensor] = []
|
||||
qweight_l: list[torch.Tensor] = []
|
||||
scales_l: list[torch.Tensor] = []
|
||||
@@ -601,11 +680,13 @@ class MarlinMoEWeightData:
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref, qweight, scales, global_scale = (
|
||||
rand_marlin_weight_nvfp4_like(w[i], group_size)
|
||||
rand_marlin_weight_nvfp4_like(
|
||||
w[i], group_size, input_dtype=input_dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
|
||||
w[i], group_size
|
||||
w[i], group_size, input_dtype=input_dtype
|
||||
)
|
||||
global_scale = None
|
||||
|
||||
@@ -615,13 +696,18 @@ class MarlinMoEWeightData:
|
||||
if global_scale is not None:
|
||||
global_scale_l.append(global_scale)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
|
||||
w_ref, qweight, scales = marlin_quant_fp8_torch(
|
||||
w[i], group_size, input_dtype=input_dtype
|
||||
)
|
||||
w_ref_l.append(w_ref.T)
|
||||
qweight_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
elif has_zp:
|
||||
w_ref, qweight, scales, zeros = awq_marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size
|
||||
w[i].transpose(1, 0),
|
||||
quant_type,
|
||||
group_size,
|
||||
input_dtype=input_dtype,
|
||||
)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
@@ -631,7 +717,12 @@ class MarlinMoEWeightData:
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
w[i].transpose(1, 0),
|
||||
quant_type,
|
||||
group_size,
|
||||
act_order,
|
||||
test_perm,
|
||||
input_dtype=input_dtype,
|
||||
)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
@@ -652,11 +743,18 @@ class MarlinMoEWeightData:
|
||||
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
|
||||
marlin_bias = stack_and_dev(bias_l) if bias_l else None
|
||||
|
||||
a_scales_factor = None
|
||||
if input_type == scalar_types.int8 and group_size != -1:
|
||||
a_scales_factor = 1 / 4096 * scales.max().float()
|
||||
scales = scales / scales.max() * 4096
|
||||
scales = scales.round().to(torch.int16).view(w.dtype)
|
||||
|
||||
return MarlinMoEWeightData(
|
||||
w_ref=w_ref,
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
global_scale=global_scale,
|
||||
a_scales_factor=a_scales_factor,
|
||||
g_idx=g_idx,
|
||||
zeros=zeros,
|
||||
sort_indices=sort_indices,
|
||||
@@ -666,28 +764,47 @@ class MarlinMoEWeightData:
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.parametrize(
|
||||
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
|
||||
(
|
||||
"a_type, b_type, c_type, group_blocks,"
|
||||
"m, n, k, e, topk, ep_size, act_order, is_k_full"
|
||||
),
|
||||
marlin_moe_generate_valid_test_cases(),
|
||||
)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
quant_type: ScalarType,
|
||||
is_k_full: bool,
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
ep_size,
|
||||
act_order,
|
||||
is_k_full,
|
||||
):
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.cuda.manual_seed(1)
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
|
||||
if c_type == scalar_types.float16:
|
||||
dtype = torch.float16
|
||||
elif c_type == scalar_types.bfloat16:
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise RuntimeError("unsupported c_type")
|
||||
|
||||
if a_type == scalar_types.int8:
|
||||
a_dtype = torch.int8
|
||||
elif a_type == scalar_types.float8_e4m3fn:
|
||||
a_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
a_dtype = dtype
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
@@ -700,11 +817,19 @@ def test_fused_marlin_moe(
|
||||
e_map = None
|
||||
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
w=w1,
|
||||
quant_type=b_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
input_type=a_type,
|
||||
)
|
||||
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
w=w2,
|
||||
quant_type=b_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
input_type=a_type,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
@@ -712,8 +837,18 @@ def test_fused_marlin_moe(
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(
|
||||
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
torch_output = torch_experts(
|
||||
a,
|
||||
w1_data.w_ref,
|
||||
w2_data.w_ref,
|
||||
topk_weight=topk_weight,
|
||||
topk_ids=topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
quant_dtype=a_dtype,
|
||||
per_act_token_quant=True,
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
@@ -733,15 +868,18 @@ def test_fused_marlin_moe(
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
input_global_scale1=w1_data.a_scales_factor,
|
||||
input_global_scale2=w2_data.a_scales_factor,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
input_dtype=a_dtype,
|
||||
quant_type_id=b_type.id,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
|
||||
Reference in New Issue
Block a user