[Misc][Kernel]: Add GPTQAllSpark Quantization (#12931)
This commit is contained in:
@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor!? azp) -> ()");
|
||||
// Bind the CUDA implementation of the dynamic int8 quantization op
// declared by the preceding ops.def() schema.
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant);
||||
#ifndef USE_ROCM
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||
"Tensor? b_zeros, "
|
||||
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
|
||||
"Tensor!? b_zeros_reorder, "
|
||||
"int K, int N, int N_32align) -> ()");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// AllSpark quantization ops
|
||||
ops.def(
|
||||
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
|
||||
"Tensor? b_qzeros, "
|
||||
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
|
||||
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
||||
// conditionally compiled so impl in source file
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
Reference in New Issue
Block a user