diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index 03944dc0d..14c136dcb 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -237,13 +237,10 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) }; dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, {b_k_stride_, b_n_stride_}); -#ifdef __aarch64__ + // dummy M size for prepacking weights // Prepacking weights improves performance and avoid runtime reorders constexpr dnnl_dim_t kProbeM = 128; -#else - constexpr dnnl_dim_t kProbeM = DNNL_RUNTIME_DIM_VAL; -#endif prepack_weight(args.b_ptr, original_b_md, create_primitive_desc( @@ -411,21 +408,19 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args) dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, {b_k_stride_, b_n_stride_}); + // dummy M size for prepacking weights + // Prepacking weights improves performance and avoids runtime reorders + constexpr dnnl_dim_t kProbeM = 128; + prepack_weight(args.b_ptr, original_b_md, create_primitive_desc( - MSizeCacheKey{ -#ifdef VLLM_USE_ACL - // Arm Compute Library (ACL) backend for oneDNN does - // not support runtime - // dimensions, so we set M to a default value - .a_m_size = 128, - .a_m_stride = b_k_size_, -#else - .a_m_size = DNNL_RUNTIME_DIM_VAL, - .a_m_stride = DNNL_RUNTIME_DIM_VAL, -#endif - .use_bias = false, - .bias_type = dnnl::memory::data_type::undef}, + MSizeCacheKey{// Use a concrete M so oneDNN's kernel + // selector can choose an optimally blocked + // weight layout. + .a_m_size = kProbeM, + .a_m_stride = b_k_size_, + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, true) .weights_desc()); init_runtime_memory_cache(args);