[CPU] Refactor CPU unquantized linear (#24150)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-09-04 14:28:45 +08:00
committed by GitHub
parent cb55ad86fe
commit 57b1ce94f7
9 changed files with 466 additions and 26 deletions

View File

@@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) {
delete ptr;
}
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
template <typename KT, typename VT>
class DNNLPrimitiveCache {
public:
@@ -166,6 +183,23 @@ struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
hash<int>()(static_cast<int>(val.bias_type));
}
};
template <>
struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
size_t operator()(
const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size);
}
};
template <>
struct hash<MatMulPrimitiveHandler::MSizeCacheKey> {
size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
return hash<dnnl_dim_t>()(val.a_m_size) ^
hash<dnnl_dim_t>()(val.a_m_stride) ^ hash<bool>()(val.use_bias) ^
hash<int>()(static_cast<int>(val.bias_type));
}
};
} // namespace std
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
@@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
l.bias_type == r.bias_type;
}
bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size;
}
bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
const MatMulPrimitiveHandler::MSizeCacheKey& r) {
return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride &&
l.use_bias == r.use_bias && l.bias_type == r.bias_type;
}
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
get_w8a8_class_primitive_cache(
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
@@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
}
dnnl::matmul matmul = get_matmul_cache(args);
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
}
@@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
}
@@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
@@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// For PER_TOKEN, scales will be applied in outside epilogue
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
attr.set_scales_mask(DNNL_ARG_SRC, 0);
@@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
attr);
}
}
MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
m_size_cache_(nullptr) {
assert(ab_type_ == dnnl::memory::data_type::f32 ||
ab_type_ == dnnl::memory::data_type::bf16 ||
ab_type_ == dnnl::memory::data_type::f16);
prepack_weight(args.b_ptr,
create_primitive_desc(
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
.weights_desc());
init_runtime_memory_cache(args);
}
static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
get_matul_class_primitive_cache(
const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
int64_t cache_size) {
static MatMulPrimitiveHandler::ClassMatmulCache cache(128);
assert(cache_size > 0);
return cache.get_or_create(key, [&]() {
return std::make_shared<MatMulPrimitiveHandler::MSizeCache>(cache_size);
});
}
void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
a_storage->set_data_handle((void*)args.a_ptr);
a_mem_desc->dims[0] = args.a_m_size;
a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride;
c_storage->set_data_handle((void*)args.c_ptr);
c_mem_desc->dims[0] = args.a_m_size;
if (args.use_bias) {
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
bias_storage->set_data_handle((void*)args.bias_ptr);
}
dnnl::matmul matmul = get_matmul_cache(args);
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
}
dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
const MSizeCacheKey& key) {
if (m_size_cache_.get() == nullptr) {
ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_};
m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
}
dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
const MSizeCacheKey& key, bool first_time) {
dnnl::memory::desc a_md;
dnnl::memory::desc b_md;
if (first_time) {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
dnnl::memory::format_tag::ab);
b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
dnnl::memory::format_tag::any);
} else {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
{key.a_m_stride, 1});
b_md = b_target_mem_desc_;
}
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (key.use_bias) {
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
c_md, attr);
} else {
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
}
}
void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
{{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
memory_cache_[DNNL_ARG_DST] =
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
default_engine(), nullptr);
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
memory_cache_[DNNL_ARG_BIAS] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}