[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
|
||||
"int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
|
||||
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
|
||||
"bool is_zp_float) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
|
||||
Reference in New Issue
Block a user