[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
"Tensor? b_bias_or_none,"
|
||||
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
|
||||
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
||||
"Tensor? b_bias_or_none,Tensor b_scales, "
|
||||
"Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
|
||||
"Tensor? "
|
||||
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
@@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// gptq_marlin repack from GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
|
||||
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
|
||||
"SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// awq_marlin repack from AWQ.
|
||||
ops.def(
|
||||
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
||||
"SymInt size_n, int num_bits) -> Tensor");
|
||||
"SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// preprocess W-int4A-fp8 weight for marlin kernel
|
||||
ops.def(
|
||||
"marlin_int4_fp8_preprocess(Tensor qweight, "
|
||||
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
|
||||
Reference in New Issue
Block a user