[Perf] Use dummy M for weight prepacking on x86 (#35890)

Signed-off-by: Li, Tianmu <tianmu.li@intel.com>
This commit is contained in:
Tianmu Li
2026-03-04 20:56:49 -08:00
committed by GitHub
parent 0a12cea25f
commit 8e7820131e

View File

@@ -237,13 +237,10 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
}; };
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_}); {b_k_stride_, b_n_stride_});
#ifdef __aarch64__
// dummy M size for prepacking weights // dummy M size for prepacking weights
// Prepacking weights improves performance and avoid runtime reorders // Prepacking weights improves performance and avoid runtime reorders
constexpr dnnl_dim_t kProbeM = 128; constexpr dnnl_dim_t kProbeM = 128;
#else
constexpr dnnl_dim_t kProbeM = DNNL_RUNTIME_DIM_VAL;
#endif
prepack_weight(args.b_ptr, original_b_md, prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc( create_primitive_desc(
@@ -411,21 +408,19 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_}); {b_k_stride_, b_n_stride_});
// dummy M size for prepacking weights
// Prepacking weights improves performance and avoid runtime reorders
constexpr dnnl_dim_t kProbeM = 128;
prepack_weight(args.b_ptr, original_b_md, prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc( create_primitive_desc(
MSizeCacheKey{ MSizeCacheKey{// Use a concrete M so oneDNN's kernel
#ifdef VLLM_USE_ACL // selector can choose an optimally blocked
// Arm Compute Library (ACL) backend for oneDNN does // weight layout.
// not support runtime .a_m_size = kProbeM,
// dimensions, so we set M to a default value .a_m_stride = b_k_size_,
.a_m_size = 128, .use_bias = false,
.a_m_stride = b_k_size_, .bias_type = dnnl::memory::data_type::undef},
#else
.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
#endif
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true) true)
.weights_desc()); .weights_desc());
init_runtime_memory_cache(args); init_runtime_memory_cache(args);