[Hardware][Power] Enable compressed tensor W8A8 INT8 quantization for POWER (#17153)

Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Akash kaothalkar
2025-05-08 11:05:03 +05:30
committed by GitHub
parent 5a499e70d5
commit e515668edf
4 changed files with 669 additions and 5 deletions

View File

@@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& bias);
#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const std::optional<torch::Tensor>& bias);
#endif
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
// W8A8 GEMM, supporting symmetric quantization.
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif
// SHM CCL