[cpu][perf] Accelerate unquantized-linear for AArch64 through oneDNN/ACL and weight prepack (#25948)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com>
This commit is contained in:
@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::prepack_weight(
|
||||
void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
void* original_b_ptr, dnnl::memory::desc original_b_md,
|
||||
dnnl::memory::desc b_target_mem_desc) {
|
||||
dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
|
||||
dnnl::memory packed_weight(b_target_mem_desc, default_engine());
|
||||
{
|
||||
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
|
||||
if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
|
||||
assert(!use_azp_);
|
||||
};
|
||||
prepack_weight(args.b_ptr,
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
prepack_weight(args.b_ptr, original_b_md,
|
||||
create_primitive_desc(
|
||||
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
.use_bias = false,
|
||||
@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
|
||||
assert(ab_type_ == dnnl::memory::data_type::f32 ||
|
||||
ab_type_ == dnnl::memory::data_type::bf16 ||
|
||||
ab_type_ == dnnl::memory::data_type::f16);
|
||||
prepack_weight(args.b_ptr,
|
||||
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
|
||||
prepack_weight(args.b_ptr, original_b_md,
|
||||
create_primitive_desc(
|
||||
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
|
||||
.use_bias = false,
|
||||
.bias_type = dnnl::memory::data_type::undef},
|
||||
MSizeCacheKey{
|
||||
#ifdef VLLM_USE_ACL
|
||||
// Arm Compute Library (ACL) backend for oneDNN does
|
||||
// not support runtime
|
||||
// dimensions, so we set M to a default value
|
||||
.a_m_size = 128,
|
||||
.a_m_stride = b_k_size_,
|
||||
#else
|
||||
.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
|
||||
#endif
|
||||
.use_bias = false,
|
||||
.bias_type = dnnl::memory::data_type::undef},
|
||||
true)
|
||||
.weights_desc());
|
||||
init_runtime_memory_cache(args);
|
||||
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
|
||||
c_storage->set_data_handle((void*)args.c_ptr);
|
||||
c_mem_desc->dims[0] = args.a_m_size;
|
||||
|
||||
#ifndef VLLM_USE_ACL
|
||||
// We do not support in ACL backend of oneDNN, we handle bias by:
|
||||
// 1. copying it into the result tensor
|
||||
// 2. attaching a fused-sum post-op to the matmul primitive
|
||||
if (args.use_bias) {
|
||||
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
|
||||
bias_storage->set_data_handle((void*)args.bias_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
dnnl::matmul matmul = get_matmul_cache(args);
|
||||
|
||||
// With ACL backend of oneDNN, the required memory format might change when the
|
||||
// source tensor dims change. This does not really happen in practice, so isn't
|
||||
// a performance hit, but we need to support it because the API allows for it.
|
||||
#ifdef VLLM_USE_ACL
|
||||
auto new_expected_wei_desc =
|
||||
dnnl::matmul::primitive_desc(
|
||||
const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
|
||||
.weights_desc();
|
||||
if (new_expected_wei_desc != b_target_mem_desc_) {
|
||||
prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
|
||||
b_target_mem_desc_, new_expected_wei_desc);
|
||||
}
|
||||
#endif
|
||||
|
||||
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
|
||||
scratchpad_storage->set_data_handle(
|
||||
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
|
||||
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
|
||||
} else {
|
||||
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
|
||||
{key.a_m_stride, 1});
|
||||
#ifdef VLLM_USE_ACL
|
||||
// ACL's backend of oneDNN always expects the weight format to be "any"
|
||||
b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
|
||||
dnnl::memory::format_tag::any);
|
||||
#else
|
||||
b_md = b_target_mem_desc_;
|
||||
#endif
|
||||
}
|
||||
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
|
||||
dnnl::memory::format_tag::ab);
|
||||
@@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
|
||||
|
||||
if (key.use_bias) {
|
||||
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
|
||||
// Since ACL's matmuls don't support passing a bias_md, we apply the bias
|
||||
// through a fused-sum post-op
|
||||
#ifdef VLLM_USE_ACL
|
||||
dnnl::post_ops post_ops;
|
||||
post_ops.append_sum();
|
||||
attr.set_post_ops(post_ops);
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||
attr);
|
||||
#else
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
|
||||
c_md, attr);
|
||||
#endif
|
||||
} else {
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||
attr);
|
||||
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
|
||||
|
||||
// ACL matmuls don't support bias_md, so we don't need these
|
||||
#ifndef VLLM_USE_ACL
|
||||
memory_cache_[DNNL_ARG_BIAS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
|
||||
|
||||
#endif
|
||||
memory_cache_[DNNL_ARG_SCRATCHPAD] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
|
||||
}
|
||||
|
||||
bool is_onednn_acl_supported() {
|
||||
#ifdef VLLM_USE_ACL
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user