[Fix] [CPU Backend] : Prepack weights for w8a8 oneDNN matmul (#33901)
Signed-off-by: nikhil-arm <nikhil.gupta2@arm.com>
This commit is contained in:
@@ -237,12 +237,20 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
|
||||
};
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
#ifdef __aarch64__
|
||||
// dummy M size for prepacking weights
|
||||
// Prepacking weights improves performance and avoids runtime reorders
|
||||
constexpr dnnl_dim_t kProbeM = 128;
|
||||
#else
|
||||
constexpr dnnl_dim_t kProbeM = DNNL_RUNTIME_DIM_VAL;
|
||||
#endif
|
||||
|
||||
prepack_weight(args.b_ptr, original_b_md,
|
||||
create_primitive_desc(
|
||||
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
MSizeCacheKey{.a_m_size = kProbeM,
|
||||
.use_bias = false,
|
||||
.bias_type = dnnl::memory::data_type::undef},
|
||||
true)
|
||||
/*first_time=*/true)
|
||||
.weights_desc());
|
||||
init_runtime_memory_cache(args);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user