[Quantization] add marlin w4a8/w8a8 check (#31061)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
     a_scales = None
     if input_dtype == torch.int8:
+        assert quant_type == scalar_types.uint4, (
+            "W8A8-INT8 is not supported by marlin kernel."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
         a_scales = a_scales * input_global_scale
     elif input_dtype == torch.float8_e4m3fn:
+        assert quant_type == scalar_types.uint4, (
+            "INT8 weight + FP8 activation is not supported."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

     output = ops.gptq_marlin_gemm(
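For reference, the rule this hunk enforces can be restated as a standalone predicate. The sketch below is hypothetical (the helper name and the string-based quant-type comparison are illustrative, not vLLM API); it only mirrors the assertions added above: int8 or fp8 activations are accepted solely with 4-bit weights (W4A8), while 8-bit weights combined with quantized activations (W8A8) are rejected.

import torch

def check_marlin_act_quant(quant_type_name: str, input_dtype: torch.dtype) -> None:
    # Mirrors the asserts added in this commit: only W4A8 combinations pass.
    if input_dtype == torch.int8 and quant_type_name not in ("uint4", "uint4b8"):
        raise AssertionError("W8A8-INT8 is not supported by marlin kernel.")
    if input_dtype == torch.float8_e4m3fn and quant_type_name not in ("uint4", "uint4b8"):
        raise AssertionError("INT8 weight + FP8 activation is not supported.")

check_marlin_act_quant("uint4", torch.int8)             # W4A8-INT8: passes
check_marlin_act_quant("uint4b8", torch.float8_e4m3fn)  # W4A8-FP8: passes
# check_marlin_act_quant("uint8", torch.int8)           # W8A8-INT8: raises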
@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
     a_scales = None
     if input_dtype == torch.int8:
+        assert quant_type == scalar_types.uint4b8, (
+            "W8A8-INT8 is not supported by marlin kernel."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
         a_scales = a_scales * input_global_scale
     elif input_dtype == torch.float8_e4m3fn:
+        assert quant_type == scalar_types.uint4b8, (
+            "INT8 weight + FP8 activation is not supported."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

     output = ops.gptq_marlin_gemm(
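For context, marlin_quant_input returns the quantized activations together with activation scales (a_scales). A minimal sketch of that step, assuming per-token symmetric int8 quantization (the stand-in helper below and the per-token granularity are assumptions, not the actual vLLM implementation):

import torch

def quant_input_int8_sketch(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row) of the reshaped (tokens, hidden) activations.
    a_scales = x.abs().amax(dim=-1, keepdim=True).float().clamp_min(1e-6) / 127.0
    x_q = torch.clamp(torch.round(x / a_scales), -128, 127).to(torch.int8)
    return x_q, a_scales

x = torch.randn(4, 16, dtype=torch.float16)
x_q, a_scales = quant_input_int8_sketch(x)
# In the int8 branch above, the per-activation scales are then folded with the
# weight's input_global_scale before the Marlin GEMM is invoked:
input_global_scale = torch.tensor(1.0)  # placeholder value
a_scales = a_scales * input_global_scale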