[Quantization] add marlin w4a8/w8a8 check (#31061)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
This commit is contained in:
Jinzhen Lin
2025-12-21 05:58:11 +08:00
committed by GitHub
parent ae0770fa6b
commit 7c73ceb581
3 changed files with 28 additions and 0 deletions

View File

@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(