[CPU] Refactor CPU WNA16 (#28826)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-11-19 10:32:00 +08:00
committed by GitHub
parent 40b6b38f2c
commit 20852c8f4c
22 changed files with 1656 additions and 78 deletions

View File

@@ -103,6 +103,13 @@ void cpu_attention_with_kv_cache(
// Note: just for avoiding importing errors
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
torch::Tensor& output, const torch::Tensor& scales,
const std::optional<torch::Tensor>& zeros,
const std::optional<torch::Tensor>& g_idx,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@@ -283,6 +290,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
// WNA16
#if defined(__AVX512F__)
ops.def(
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {