[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Author: Jinzhen Lin
Date: 2025-03-08 00:53:38 +08:00
Committed by: GitHub
Parent: 58abe35455
Commit: d0feea31c7
6 changed files with 99 additions and 24 deletions


@@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+      "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
+      "bool is_zp_float) -> Tensor",
       {stride_tag});
   // conditionally compiled so impl registration is in source file