[Kernel] fp4 marlin kernel (#17687)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-05-11 10:58:49 +08:00
committed by GitHub
parent ca66a1674c
commit d74e5f37bc
21 changed files with 1216 additions and 331 deletions

View File

@@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(
has_zp: bool,
has_zp: Optional[bool] = None,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
@@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
if device_capability < 80:
return []
# - has_zp is True: return quant_types that have zero points
# - has_zp is False: return quant_types that do not have zero points
# - has_zp is None: return both
if has_zp is None:
types0 = query_marlin_supported_quant_types(False, include_fp_type,
device_capability)
types1 = query_marlin_supported_quant_types(True, include_fp_type,
device_capability)
return types0 + types1
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4]
@@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
# GPTQ style, unsigned + symmetric bias
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
return res
@@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,
@@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
None,
weight,
weight_scale,
None,
weight_zp,
g_idx,
g_idx_sort_indices,