[Hardware][CPU] compressed-tensor INT8 W8A8 AZP support (#9344)

This commit is contained in:
Li, Jiang
2024-10-18 00:21:04 +08:00
committed by GitHub
parent 8e1cddcd44
commit 5eda21e773
7 changed files with 452 additions and 96 deletions

View File

@@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b_scales,
const c10::optional<torch::Tensor>& bias);
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b, const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const torch::Tensor& azp_adj,
const c10::optional<torch::Tensor>& azp,
const c10::optional<torch::Tensor>& bias);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif
}