From caad9f1e01ee04e4f5912d0287031ea3a850f6dc Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Mon, 9 Feb 2026 10:04:41 +0000 Subject: [PATCH] [Fix] [CPU Backend] : Prepack weights for w8a8 oneDNN matmul (#33901) Signed-off-by: nikhil-arm --- csrc/cpu/dnnl_helper.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index e337e10e1..03944dc0d 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -237,12 +237,20 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) }; dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, {b_k_stride_, b_n_stride_}); +#ifdef __aarch64__ + // dummy M size for prepacking weights + // Prepacking weights improves performance and avoid runtime reorders + constexpr dnnl_dim_t kProbeM = 128; +#else + constexpr dnnl_dim_t kProbeM = DNNL_RUNTIME_DIM_VAL; +#endif + prepack_weight(args.b_ptr, original_b_md, create_primitive_desc( - MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, + MSizeCacheKey{.a_m_size = kProbeM, .use_bias = false, .bias_type = dnnl::memory::data_type::undef}, - true) + /*first_time=*/true) .weights_desc()); init_runtime_memory_cache(args); }