[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
Author: Jinzhen Lin
Date: 2025-11-29 23:19:33 +08:00
Committed by: GitHub
Parent: fa59fe417f
Commit: 1656ad3704
46 changed files with 4371 additions and 2240 deletions


@@ -9,6 +9,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -286,10 +291,10 @@ def get_scale_perms():
def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
if group_size < size_k and group_size != -1 and not is_a_8bit:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
@@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
return s.reshape(*origin_shape).contiguous()
def marlin_act_int8_process_scales(s: torch.Tensor):
a_scales_scale_factor = 1 / 4096 * s.max().float()
s = s / s.max() * 4096
s = s.round().to(torch.int16).view(s.dtype)
return s, a_scales_scale_factor
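# --- Editor's illustration (not part of the diff) ---------------------------------
# marlin_act_int8_process_scales() stores each group scale as a rounded integer in
# [0, 4096] whose int16 bit pattern is viewed back through the original float dtype,
# plus one per-tensor factor (s.max() / 4096) that undoes the rescaling. A quick
# round-trip check under that reading, with hypothetical fp16 group scales:
s_demo = torch.rand(8, 16, dtype=torch.float16) + 0.5
s_enc, a_factor = marlin_act_int8_process_scales(s_demo)
s_rec = s_enc.view(torch.int16).float() * a_factor  # recover the original scales
assert torch.allclose(s_rec, s_demo.float(), atol=2 * a_factor.item())
# -----------------------------------------------------------------------------------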
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
):
num_experts = s.shape[0]
output = torch.empty(
@@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
return output
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
@@ -339,7 +348,8 @@ def marlin_zero_points(
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
if not is_a_8bit:
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
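# Editor's note (not part of the diff): with is_a_8bit=True the usual Marlin column
# interleave is skipped and only the reshape plus pack_cols() packing is applied,
# presumably because the 8-bit-activation dequantization path reads zero-points in
# natural column order; awq_to_marlin_zero_points() and moe_awq_to_marlin_zero_points()
# below simply thread is_a_8bit through to this helper.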
@@ -347,7 +357,11 @@ def marlin_zero_points(
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
@@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
return marlin_zp
def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
@@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
output[e] = awq_to_marlin_zero_points(
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
)
return output
@@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
return True
_quant_fp8_method: QuantFP8 | None = None
def get__quant_fp8_method() -> QuantFP8:
global _quant_fp8_method
if _quant_fp8_method is None:
_quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
return _quant_fp8_method
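# Editor's note (not part of the diff): QuantFP8(False, GroupShape.PER_TOKEN)
# configures a dynamic (static=False) per-token FP8 quantizer; it is built lazily
# and cached in _quant_fp8_method so repeated Marlin calls reuse a single instance.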
def get_marlin_input_dtype(prefix):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
return torch.int8
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
)
_ = get__quant_fp8_method()
return torch.float8_e4m3fn
else:
return
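# Editor's illustration (not part of the diff): activation dtype selection is driven
# entirely by the VLLM_MARLIN_INPUT_DTYPE environment variable; the prefix argument
# is not consulted in the body above. Expected behaviour:
#   VLLM_MARLIN_INPUT_DTYPE=int8 -> torch.int8            (W4A8-INT8)
#   VLLM_MARLIN_INPUT_DTYPE=fp8  -> torch.float8_e4m3fn   (W4A8-FP8, SM89/SM120 only)
#   unset or anything else       -> None                  (keep the W4A16 path)
input_dtype = get_marlin_input_dtype("model.layers.0.mlp.gate_up_proj")  # hypothetical layer name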
def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
x = x.reshape(-1, x.shape[-1])
if quant_dtype == torch.int8:
return per_token_quant_int8(x)
elif quant_dtype == torch.float8_e4m3fn:
return get__quant_fp8_method()(x)
else:
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
@@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
assert wtype == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert wtype == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,
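# Editor's note (not part of the diff): the same activation path is added to
# apply_gptq_marlin_linear, apply_awq_marlin_linear and apply_rtn_marlin_linear:
#   input_dtype is None                -> activations stay fp16/bf16, a_scales=None (W4A16)
#   input_dtype is torch.int8          -> per-token int8 activations; the per-token scales
#                                         are folded with the per-tensor input_global_scale
#                                         before being passed to gptq_marlin_gemm (W4A8-INT8)
#   input_dtype is torch.float8_e4m3fn -> per-token FP8 activations, no global factor (W4A8-FP8)
# The assertions in apply_gptq_marlin_linear additionally require 4-bit weights
# (uint4b8), rejecting W8A8-INT8 and INT8-weight + FP8-activation combinations.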
@@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
@@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,
@@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
@@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
None,
None,