[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)

Author: YajieWang
Date:   2025-03-01 14:30:59 +08:00
Committed by: GitHub
Parent: 6a84164add
Commit: 6a92ff93e1

12 changed files with 2005 additions and 4 deletions

csrc/torch_bindings.cpp

@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+#ifndef USE_ROCM
+  // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
+  ops.def(
+      "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
+      "Tensor? b_zeros, "
+      "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
+      "Tensor!? b_zeros_reorder, "
+      "int K, int N, int N_32align) -> ()");
+  // conditionally compiled so impl in source file
+
+  // AllSpark quantization ops
+  ops.def(
+      "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
+      "Tensor? b_qzeros, "
+      "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
+      "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
+  // conditionally compiled so impl in source file
+#endif
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
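For context, the two schemas added above are consumed from Python once the extension library is loaded. Below is a minimal sketch of the call sequence, assuming vLLM's usual torch.ops._C namespace; the shapes, dtypes, group_size convention, and CUBLAS_M_THRESHOLD value are illustrative assumptions, not the kernel's documented contract.

    # Minimal sketch: driving the two ops registered above from Python.
    # Assumes vLLM's `_C` extension namespace; shapes, dtypes, and the
    # threshold below are illustrative guesses, not the kernel's contract.
    import torch

    K, N = 4096, 11008
    N_32align = (N + 31) // 32 * 32  # N rounded up to a multiple of 32

    a = torch.randn(8, K, dtype=torch.float16, device="cuda")  # fp16 activations
    b_qweight = torch.randint(0, 256, (K, N), dtype=torch.uint8, device="cuda")  # uint8 weight, K x N
    b_scales = torch.randn(N, dtype=torch.float16, device="cuda")  # per-channel scales (assumed layout)
    has_zp = False  # symmetric quantization, so no zero points

    # Destination buffers for the n32k16-reordered weight (assumed shapes).
    b_qweight_reorder = torch.empty(N_32align, K, dtype=torch.uint8, device="cuda")
    b_scales_reorder = torch.empty(N_32align, dtype=torch.float16, device="cuda")

    # Repack the K x N weight into the tiled order the Ampere kernel expects.
    torch.ops._C.rearrange_kn_weight_as_n32k16_order(
        b_qweight, b_scales, None, has_zp,
        b_qweight_reorder, b_scales_reorder, None,
        K, N, N_32align)

    props = torch.cuda.get_device_properties(0)
    sm_count = props.multi_processor_count
    sm_version = props.major * 10 + props.minor  # e.g. 80 for A100
    CUBLAS_M_THRESHOLD = 1024  # assumed cutover between the fused kernel and a cuBLAS path

    out = torch.ops._C.allspark_w8a16_gemm(
        a, b_qweight_reorder, b_scales_reorder, None,
        N, -1, sm_count, sm_version, CUBLAS_M_THRESHOLD,
        has_zp, True)  # group_size=-1: per-channel (assumed); n32k16_reorder=True
    print(out.shape)  # expected: (8, N)

Passing sm_count and sm_version explicitly lets the caller query device properties once and reuse them across calls, rather than having the kernel probe the device on every GEMM.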