[Kernel] fp4 marlin kernel (#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -33,7 +33,7 @@ USE_FP32_REDUCE_DEFAULT = True
|
||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||
# TODO: we may want to move this into the C++ so it's closer to the actual impl
|
||||
def query_marlin_supported_quant_types(
|
||||
has_zp: bool,
|
||||
has_zp: Optional[bool] = None,
|
||||
include_fp_type: bool = True,
|
||||
device_capability: Optional[int] = None,
|
||||
):
|
||||
@@ -45,6 +45,16 @@ def query_marlin_supported_quant_types(
|
||||
if device_capability < 80:
|
||||
return []
|
||||
|
||||
# - has_zp is True: return quant_types that have zero points
|
||||
# - has_zp is False: return quant_types that do not have zero points
|
||||
# - has_zp is None: both
|
||||
if has_zp is None:
|
||||
types0 = query_marlin_supported_quant_types(False, include_fp_type,
|
||||
device_capability)
|
||||
types1 = query_marlin_supported_quant_types(True, include_fp_type,
|
||||
device_capability)
|
||||
return types0 + types1
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
return [scalar_types.uint4]
|
||||
@@ -52,7 +62,7 @@ def query_marlin_supported_quant_types(
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
res = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
if include_fp_type:
|
||||
res += [scalar_types.float8_e4m3fn]
|
||||
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
|
||||
return res
|
||||
|
||||
|
||||
@@ -394,6 +404,7 @@ def apply_gptq_marlin_linear(
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
@@ -439,6 +450,7 @@ def apply_awq_marlin_linear(
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
|
||||
Reference in New Issue
Block a user