[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -9,6 +9,11 @@ import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
@@ -286,10 +291,10 @@ def get_scale_perms():
|
||||
|
||||
|
||||
def marlin_permute_scales(
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
|
||||
) -> torch.Tensor:
|
||||
scale_perm, scale_perm_single = get_scale_perms()
|
||||
if group_size < size_k and group_size != -1:
|
||||
if group_size < size_k and group_size != -1 and not is_a_8bit:
|
||||
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
||||
else:
|
||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
||||
@@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
|
||||
return s.reshape(*origin_shape).contiguous()
|
||||
|
||||
|
||||
def marlin_act_int8_process_scales(s: torch.Tensor):
|
||||
a_scales_scale_factor = 1 / 4096 * s.max().float()
|
||||
s = s / s.max() * 4096
|
||||
s = s.round().to(torch.int16).view(s.dtype)
|
||||
return s, a_scales_scale_factor
|
||||
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
s: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
group_size: int,
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
|
||||
):
|
||||
num_experts = s.shape[0]
|
||||
output = torch.empty(
|
||||
@@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
|
||||
)
|
||||
|
||||
for e in range(num_experts):
|
||||
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
|
||||
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
|
||||
return output
|
||||
|
||||
|
||||
def marlin_zero_points(
|
||||
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
|
||||
) -> torch.Tensor:
|
||||
# Permute zero-points in a similar way to scales, but do not use the
|
||||
# "single" permutation, since zero-points are applied on every MMA
|
||||
@@ -339,7 +348,8 @@ def marlin_zero_points(
|
||||
else:
|
||||
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
if not is_a_8bit:
|
||||
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
zp = zp.reshape((-1, size_n)).contiguous()
|
||||
zp = pack_cols(zp, num_bits, size_k, size_n)
|
||||
|
||||
@@ -347,7 +357,11 @@ def marlin_zero_points(
|
||||
|
||||
|
||||
def awq_to_marlin_zero_points(
|
||||
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
q_zp_packed: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# AWQ zero-points are quantized and packed on the column dim.
|
||||
# In addition, the values are permuted based on dequantizer.
|
||||
@@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
|
||||
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
|
||||
q_zp = q_zp.reshape((-1, size_n)).contiguous()
|
||||
|
||||
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
|
||||
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
|
||||
return marlin_zp
|
||||
|
||||
|
||||
def moe_awq_to_marlin_zero_points(
|
||||
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
q_zp_packed: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
is_a_8bit: bool = False,
|
||||
):
|
||||
num_experts = q_zp_packed.shape[0]
|
||||
output = torch.empty(
|
||||
@@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
|
||||
dtype=q_zp_packed.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
|
||||
output[e] = awq_to_marlin_zero_points(
|
||||
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
|
||||
return True
|
||||
|
||||
|
||||
_quant_fp8_method: QuantFP8 | None = None
|
||||
|
||||
|
||||
def get__quant_fp8_method() -> QuantFP8:
|
||||
global _quant_fp8_method
|
||||
if _quant_fp8_method is None:
|
||||
_quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
|
||||
return _quant_fp8_method
|
||||
|
||||
|
||||
def get_marlin_input_dtype(prefix):
|
||||
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
|
||||
return
|
||||
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
|
||||
return torch.int8
|
||||
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
|
||||
if not current_platform.is_device_capability(
|
||||
89
|
||||
) and not current_platform.is_device_capability(120):
|
||||
raise ValueError(
|
||||
"Marlin W4A8-FP8 only support SM89 or SM120 device "
|
||||
"(It is slower than Marlin W4A16 on other devices). "
|
||||
"You can consider using W4A8-INT8 instead"
|
||||
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
|
||||
)
|
||||
|
||||
_ = get__quant_fp8_method()
|
||||
return torch.float8_e4m3fn
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if quant_dtype == torch.int8:
|
||||
return per_token_quant_int8(x)
|
||||
elif quant_dtype == torch.float8_e4m3fn:
|
||||
return get__quant_fp8_method()(x)
|
||||
else:
|
||||
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
|
||||
|
||||
|
||||
def apply_gptq_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
is_k_full: bool,
|
||||
input_global_scale: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
||||
@@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
a_scales = None
|
||||
if input_dtype == torch.int8:
|
||||
assert wtype == scalar_types.uint4b8, (
|
||||
"W8A8-INT8 is not supported by marlin kernel."
|
||||
)
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
a_scales = a_scales * input_global_scale
|
||||
elif input_dtype == torch.float8_e4m3fn:
|
||||
assert wtype == scalar_types.uint4b8, (
|
||||
"INT8 weight + FP8 activation is not supported."
|
||||
)
|
||||
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
a_scales,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
@@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
|
||||
quant_type: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_global_scale: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
||||
@@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
a_scales = None
|
||||
if input_dtype == torch.int8:
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
a_scales = a_scales * input_global_scale
|
||||
elif input_dtype == torch.float8_e4m3fn:
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
a_scales,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
@@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
|
||||
quant_type: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_global_scale: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
input_dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
||||
@@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
a_scales = None
|
||||
if input_dtype == torch.int8:
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
a_scales = a_scales * input_global_scale
|
||||
elif input_dtype == torch.float8_e4m3fn:
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
bias,
|
||||
weight_scale,
|
||||
a_scales,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
|
||||
Reference in New Issue
Block a user