[CPU] Refactor CPU WNA16 (#28826)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -103,6 +103,13 @@ void cpu_attention_with_kv_cache(
|
||||
// Note: just for avoiding importing errors
|
||||
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
|
||||
|
||||
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
|
||||
torch::Tensor& output, const torch::Tensor& scales,
|
||||
const std::optional<torch::Tensor>& zeros,
|
||||
const std::optional<torch::Tensor>& g_idx,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
const int64_t pack_factor, const std::string& isa_hint);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -283,6 +290,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
|
||||
|
||||
// WNA16
|
||||
#if defined(__AVX512F__)
|
||||
ops.def(
|
||||
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
|
||||
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
|
||||
"pack_factor, str isa_hint) -> ()");
|
||||
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
Reference in New Issue
Block a user