[CPU] Refactor CPU WNA16 (#28826)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-11-19 10:32:00 +08:00
committed by GitHub
parent 40b6b38f2c
commit 20852c8f4c
22 changed files with 1656 additions and 78 deletions

View File

@@ -1,7 +1,6 @@
#ifndef CPU_ATTN_HPP
#define CPU_ATTN_HPP
#include <unistd.h>
#include <type_traits>
#include <cstddef>
@@ -12,6 +11,7 @@
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
#include "utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16 };

View File

@@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
explicit FP16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {}
explicit FP16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
explicit BF16Vec16(bool, void* ptr)
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
@@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(__m512 data) : reg(data) {}
// de-pack 4 bit values
explicit FP32Vec16(int64_t value, const FP32Vec16& lut) {
int64_t mask_0 = 0x0F0F0F0F0F0F0F0F;
int64_t mask_1 = 0xF0F0F0F0F0F0F0F0;
int64_t value_0 = value & mask_0;
int64_t value_1 = value & mask_1;
__m128i vec_0 = _mm_movpi64_epi64((__m64)value_0);
__m128i vec_1 = _mm_movpi64_epi64((__m64)value_1);
vec_0 = _mm_cvtepu8_epi16(vec_0);
vec_1 = _mm_cvtepu8_epi16(vec_1);
vec_1 = _mm_slli_epi16(vec_1, 4);
__m128i vec = _mm_or_si128(vec_0, vec_1);
__m512i vec_i32 = _mm512_cvtepu8_epi32(vec);
reg = _mm512_permutexvar_ps(vec_i32, lut.reg);
}
explicit FP32Vec16(const FP32Vec4& data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
@@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
return _mm512_mask_reduce_add_ps(mask, reg);
}
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
void save(float* ptr, const int elem_num) const {
@@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
_mm512_stream_ps((float*)ptr, vec.reg);
}
static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1,
void* ptr) {
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
vec_1 = _mm512_slli_epi32(vec_1, 16);
vec_0 = _mm512_or_si512(vec_0, vec_1);
_mm512_storeu_epi32(ptr, vec_0);
}
static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1,
void* ptr) {
__m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg);
__m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg);
vec_1 = _mm512_slli_epi32(vec_1, 16);
vec_0 = _mm512_or_si512(vec_0, vec_1);
_mm512_storeu_epi32(ptr, vec_0);
}
#endif
inline void mem_barrier() { _mm_mfence(); }

402
csrc/cpu/cpu_wna16.cpp Normal file
View File

@@ -0,0 +1,402 @@
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
#endif
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
#define VLLM_DISPATCH_CASE_16B_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__))
template <typename T>
void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
namespace {
using cpu_utils::ISA;
using cpu_utils::VecTypeTrait;
template <typename scalar_t, ISA isa, bool has_zp, bool use_desc_act>
class Dequantizer4b {
public:
constexpr static int32_t pack_num = 32 / 4;
using scalar_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
public:
static void dequant(int32_t* __restrict__ q_weight,
scalar_t* __restrict__ weight,
scalar_t* __restrict__ scales,
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
const int64_t scales_stride, const int64_t zeros_stride,
const int32_t k_size, const int32_t group_size) {
vec_op::FP32Vec16 lut;
if constexpr (has_zp) {
// AWQ
alignas(64) static const float LUT[16] = {
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f,
8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f};
lut = vec_op::FP32Vec16(LUT);
} else {
// GPTQ
alignas(64) static const float LUT[16] = {
-8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
lut = vec_op::FP32Vec16(LUT);
}
// per 64-bits elem contains 16 output channels
int64_t* __restrict__ curr_q_weight = reinterpret_cast<int64_t*>(q_weight);
int64_t* __restrict__ curr_zeros = reinterpret_cast<int64_t*>(zeros);
scalar_t* __restrict__ curr_weight = weight;
scalar_t* __restrict__ curr_scale = scales;
vec_op::FP32Vec16 scale_0;
vec_op::FP32Vec16 scale_1;
vec_op::FP32Vec16 zero_0;
vec_op::FP32Vec16 zero_1;
int32_t group_counter = 0;
for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) {
int64_t qwb_0 = *curr_q_weight;
int64_t qwb_1 = *(curr_q_weight + 1);
vec_op::FP32Vec16 wb_0(qwb_0, lut);
vec_op::FP32Vec16 wb_1(qwb_1, lut);
if constexpr (!use_desc_act) {
if (group_counter == 0) {
scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale));
scale_1 = vec_op::FP32Vec16(scale_0);
curr_scale += scales_stride;
if constexpr (has_zp) {
zero_0 = vec_op::FP32Vec16(*curr_zeros, lut);
zero_1 = vec_op::FP32Vec16(zero_0);
curr_zeros += zeros_stride / 2;
}
}
} else {
int32_t g_idx_0 = g_idx[k_idx];
int32_t g_idx_1 = g_idx[k_idx + 1];
scale_0 = vec_op::FP32Vec16(
scalar_vec_t(curr_scale + g_idx_0 * scales_stride));
scale_1 = vec_op::FP32Vec16(
scalar_vec_t(curr_scale + g_idx_1 * scales_stride));
if constexpr (has_zp) {
zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2),
lut);
zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2),
lut);
}
}
if constexpr (has_zp) {
wb_0 = wb_0 - zero_0;
wb_1 = wb_1 - zero_1;
}
wb_0 = wb_0 * scale_0;
wb_1 = wb_1 * scale_1;
scalar_vec_t output_vec_0(wb_0);
scalar_vec_t output_vec_1(wb_1);
// AMX needs to interlave K elements to pack as 32 bits
if constexpr (isa == ISA::AMX) {
vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight);
} else {
output_vec_0.save(curr_weight);
output_vec_1.save(curr_weight + 16);
}
// update
curr_q_weight += 2;
curr_weight += 32;
if constexpr (!use_desc_act) {
group_counter += 2;
if (group_counter == group_size) {
group_counter = 0;
}
}
}
}
};
}; // namespace
template <typename scalar_t, typename dequantizer_t, typename gemm_t>
void cpu_gemm_wna16_impl(
scalar_t* __restrict__ input, int32_t* __restrict__ q_weight,
scalar_t* __restrict__ output, scalar_t* __restrict__ scales,
int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx,
scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size,
const int32_t k_size, const int64_t input_stride,
const int64_t output_stride, const int64_t scales_group_stride,
const int64_t zeros_group_stride, const int32_t group_num,
const int32_t group_size, const int64_t pack_factor) {
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
constexpr int32_t n_block_size = 16;
static_assert(gemm_n_tile_size % n_block_size == 0);
const int32_t thread_num = omp_get_max_threads();
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const int32_t n_partition_size = [&]() {
const int64_t cache_size = cpu_utils::get_l2_size();
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
int64_t ps_thread_limit = n_size / thread_num;
ps_cache_limit =
std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size,
(int64_t)gemm_n_tile_size);
ps_thread_limit =
std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size,
(int64_t)gemm_n_tile_size);
return std::min(ps_cache_limit, ps_thread_limit);
}();
const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size;
// get buffer size
const int64_t b_buffer_size =
(((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64);
const int64_t c_buffer_size =
(((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64);
const int64_t b_buffer_offset = 0;
const int64_t c_buffer_offset = b_buffer_size;
const int64_t buffer_size = b_buffer_size + c_buffer_size;
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
thread_num);
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
scalar_t* __restrict__ b_buffer = nullptr;
float* __restrict__ c_buffer = nullptr;
{
uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
->get_data<uint8_t>() +
thread_id * buffer_size;
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
}
const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size;
const int64_t b_buffer_block_stride = n_block_size * k_size;
const int32_t zeros_block_stride = n_block_size / pack_factor;
gemm_t gemm;
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
const int32_t n_start_idx = task_id * n_partition_size;
const int32_t n_block_start_idx = n_start_idx / n_block_size;
const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx);
const int32_t n_block_num = n_num / n_block_size;
// std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n",
// thread_id, task_id, n_start_idx, n_num);
// dequant weight
{
int32_t* __restrict__ curr_q_weight =
q_weight + n_block_start_idx * q_weight_block_stride;
scalar_t* __restrict__ curr_b_buffer = b_buffer;
scalar_t* __restrict__ curr_scales = scales + n_start_idx;
int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor;
for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) {
dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales,
curr_zeros, g_idx, scales_group_stride,
zeros_group_stride, k_size, group_size);
// if (block_idx == 0 && n_start_idx == 0) {
// print_logits("depacked weight", curr_b_buffer, k_size,
// n_block_size, n_block_size);
// }
// update
curr_q_weight += q_weight_block_stride;
curr_b_buffer += b_buffer_block_stride;
curr_scales += n_block_size;
curr_zeros += zeros_block_stride;
}
}
// compute loop
{
const int32_t n_tile_num = n_num / gemm_n_tile_size;
scalar_t* __restrict__ curr_input = input;
scalar_t* __restrict__ init_bias = bias;
if (bias != nullptr) {
init_bias += n_start_idx;
}
scalar_t* __restrict__ init_output = output + n_start_idx;
for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) {
const int32_t curr_m_size =
std::min(gemm_m_tile_size, m_size - m_idx);
scalar_t* __restrict__ curr_b_buffer = b_buffer;
scalar_t* __restrict__ curr_bias = init_bias;
scalar_t* __restrict__ curr_output = init_output;
for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) {
gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size,
input_stride, b_buffer_block_stride, gemm_n_tile_size,
false);
if (bias != nullptr) {
cpu_micro_gemm::bias_epilogue<gemm_n_tile_size>(
c_buffer, curr_output, curr_bias, curr_m_size,
gemm_n_tile_size, output_stride);
curr_bias += gemm_n_tile_size;
} else {
cpu_micro_gemm::default_epilogue<gemm_n_tile_size>(
c_buffer, curr_output, curr_m_size, gemm_n_tile_size,
output_stride);
}
curr_b_buffer +=
b_buffer_block_stride * (gemm_n_tile_size / n_block_size);
curr_output += gemm_n_tile_size;
}
curr_input += gemm_m_tile_size * input_stride;
init_output += gemm_m_tile_size * output_stride;
}
}
}
}
}
void cpu_gemm_wna16(
const torch::Tensor& input, // [M, K]
const torch::Tensor&
q_weight, // [N / 16, K * 16 / pack_factor], packed as int32
torch::Tensor& output, // [M, N]
const torch::Tensor& scales, // [group_num, N]
const std::optional<torch::Tensor>&
zeros, // [group_num, N / pack_factor], packed as int32
const std::optional<torch::Tensor>& g_idx, // [K]
const std::optional<torch::Tensor>& bias, // [N]
const int64_t pack_factor, const std::string& isa_hint) {
using cpu_utils::ISA;
TORCH_CHECK_EQ(pack_factor, 8); // only supports 4bits
const int32_t a_m_size = input.size(0);
const int32_t a_k_size = input.size(1);
const int64_t a_m_stride = input.stride(0);
const int32_t b_n_size = q_weight.size(0) * 16;
TORCH_CHECK_EQ(a_k_size % 32, 0);
TORCH_CHECK_EQ(b_n_size % 32, 0);
const int32_t group_num = scales.size(0);
const int32_t group_size = a_k_size / group_num;
TORCH_CHECK_EQ(group_size % 2, 0);
const int64_t scales_group_stride = scales.stride(0);
const int64_t output_m_stride = output.stride(0);
bool has_zp = zeros.has_value();
bool use_desc_act = g_idx.has_value();
TORCH_CHECK(!(has_zp && use_desc_act));
ISA isa = [&]() {
if (isa_hint == "amx") {
return ISA::AMX;
} else if (isa_hint == "vec") {
return ISA::VEC;
} else {
TORCH_CHECK(false, "unsupported isa hint: " + isa_hint);
}
}();
int32_t* zeros_ptr = has_zp ? zeros->data_ptr<int32_t>() : nullptr;
const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0;
int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr<int32_t>() : nullptr;
VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() {
if (isa == ISA::AMX) {
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::AMX, scalar_t>;
if (has_zp) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, true, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
if (use_desc_act) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, true>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
} else {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::AMX, false, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
} else if (isa == ISA::VEC) {
using gemm_t = cpu_micro_gemm::MicroGemm<ISA::VEC, scalar_t>;
if (has_zp) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, true, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
if (use_desc_act) {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, true>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
} else {
using dequantizer_t = Dequantizer4b<scalar_t, ISA::VEC, false, false>;
cpu_gemm_wna16_impl<scalar_t, dequantizer_t, gemm_t>(
input.data_ptr<scalar_t>(), q_weight.data_ptr<int32_t>(),
output.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), zeros_ptr,
g_idx_ptr, bias.has_value() ? bias->data_ptr<scalar_t>() : nullptr,
a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride,
scales_group_stride, zeros_group_stride, group_num, group_size,
pack_factor);
return;
}
}
});
}

View File

@@ -396,9 +396,9 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
m_size_cache_(nullptr) {
assert(ab_type_ == dnnl::memory::data_type::f32 ||
ab_type_ == dnnl::memory::data_type::bf16 ||
ab_type_ == dnnl::memory::data_type::f16);
assert(b_type_ == dnnl::memory::data_type::f32 ||
b_type_ == dnnl::memory::data_type::bf16 ||
b_type_ == dnnl::memory::data_type::f16);
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});

View File

@@ -0,0 +1,245 @@
#ifndef CPU_MICRO_GEMM_AMX_HPP
#define CPU_MICRO_GEMM_AMX_HPP
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
namespace cpu_micro_gemm {
namespace {
// AMX specific
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;
typedef struct __tile_config {
uint8_t palette_id = 1;
uint8_t start_row = 0;
uint8_t reserved_0[14] = {0};
uint16_t colsb[16] = {0};
uint8_t rows[16] = {0};
} __tilecfg;
// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
// - 16
template <typename scalar_t>
class TileGemm224 {
public:
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported data type for TileGemm224");
}
};
template <>
class TileGemm224<c10::BFloat16> {
public:
using scalar_t = c10::BFloat16;
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM;
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
// B is always packed as 16 output channels block
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_4 = c_ptr;
float* __restrict__ c_tile_5 =
c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc;
float* __restrict__ c_tile_7 =
c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
const int32_t c_tile_stride = ldc * sizeof(float);
if (accum_c) {
_tile_loadd(4, c_tile_4, c_tile_stride);
_tile_loadd(5, c_tile_5, c_tile_stride);
_tile_loadd(6, c_tile_6, c_tile_stride);
_tile_loadd(7, c_tile_7, c_tile_stride);
} else {
_tile_zero(4);
_tile_zero(5);
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
_tile_dpbf16ps(4, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
_tile_dpbf16ps(5, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_dpbf16ps(6, 1, 2);
_tile_dpbf16ps(7, 1, 3);
// update ptrs
a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
_tile_stored(4, c_tile_4, c_tile_stride);
_tile_stored(5, c_tile_5, c_tile_stride);
_tile_stored(6, c_tile_6, c_tile_stride);
_tile_stored(7, c_tile_7, c_tile_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
const int32_t m_0 = AMX_TILE_ROW_NUM;
const int32_t m_1 = m - AMX_TILE_ROW_NUM;
config.rows[0] = m_0;
config.rows[1] = m_1;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = m_0;
config.rows[5] = m_0;
config.rows[6] = m_1;
config.rows[7] = m_1;
_tile_loadconfig(&config);
}
};
// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
// m, m
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
// num should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be
// m
template <typename scalar_t>
class TileGemm122 {
public:
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
TORCH_CHECK(false, "Unsupported data type for TileGemm122");
}
};
template <>
class TileGemm122<c10::BFloat16> {
public:
using scalar_t = c10::BFloat16;
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
c10::BFloat16* __restrict__ a_tile_0 = a_ptr;
c10::BFloat16* __restrict__ a_tile_1 =
a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
const int64_t a_tile_stride = lda * sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_2 = b_ptr;
c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride;
c10::BFloat16* __restrict__ b_tile_4 =
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
c10::BFloat16* __restrict__ b_tile_5 =
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
int64_t b_stride = AMX_TILE_ROW_BYTES;
float* __restrict__ c_tile_6 = c_ptr;
float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float);
int64_t c_stride = ldc * sizeof(float);
const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
const int32_t k_group_times = k_times / 2;
const bool has_tail = (k_times % 2 == 1);
if (accum_c) {
_tile_loadd(6, c_tile_6, c_stride);
_tile_loadd(7, c_tile_7, c_stride);
} else {
_tile_zero(6);
_tile_zero(7);
}
for (int32_t k = 0; k < k_group_times; ++k) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
_tile_loadd(1, a_tile_1, a_tile_stride);
_tile_stream_loadd(4, b_tile_4, b_stride);
_tile_dpbf16ps(6, 1, 4);
_tile_stream_loadd(5, b_tile_5, b_stride);
_tile_dpbf16ps(7, 1, 5);
// update ptrs
a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
}
if (has_tail) {
_tile_loadd(0, a_tile_0, a_tile_stride);
_tile_stream_loadd(2, b_tile_2, b_stride);
_tile_dpbf16ps(6, 0, 2);
_tile_stream_loadd(3, b_tile_3, b_stride);
_tile_dpbf16ps(7, 0, 3);
}
_tile_stored(6, c_tile_6, c_stride);
_tile_stored(7, c_tile_7, c_stride);
}
FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
config.rows[0] = m;
config.rows[1] = m;
config.rows[2] = AMX_TILE_ROW_NUM;
config.rows[3] = AMX_TILE_ROW_NUM;
config.rows[4] = AMX_TILE_ROW_NUM;
config.rows[5] = AMX_TILE_ROW_NUM;
config.rows[6] = m;
config.rows[7] = m;
_tile_loadconfig(&config);
}
};
} // namespace
// Gemm kernel uses AMX, requires B matrix to be packed
template <typename scalar_t>
class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
public:
static constexpr int32_t MaxMSize = 32;
static constexpr int32_t NSize = 32;
public:
MicroGemm() : curr_m_(-1) {
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
}
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
if (m > AMX_TILE_ROW_NUM) {
if (m != curr_m_) {
curr_m_ = m;
TileGemm224<scalar_t>::init_tile_config(m, amx_tile_config_);
}
TileGemm224<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
} else {
if (m != curr_m_) {
curr_m_ = m;
TileGemm122<scalar_t>::init_tile_config(m, amx_tile_config_);
}
TileGemm122<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t curr_m_;
};
} // namespace cpu_micro_gemm
#endif

View File

@@ -0,0 +1,91 @@
#ifndef CPU_MICRO_GEMM_IMPL_HPP
#define CPU_MICRO_GEMM_IMPL_HPP
#include "cpu/utils.hpp"
#include "cpu/cpu_types.hpp"
namespace cpu_micro_gemm {
#define DEFINE_CPU_MICRO_GEMM_PARAMS \
scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \
float *__restrict__ c_ptr, const int32_t m, const int32_t k, \
const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \
const bool accum_c
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
template <cpu_utils::ISA isa, typename scalar_t>
class MicroGemm {
public:
static constexpr int32_t MaxMSize = 16;
static constexpr int32_t NSize = 16;
public:
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TORCH_CHECK(false, "Unimplemented MicroGemm.");
}
};
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr,
scalar_t* __restrict__ d_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
float* __restrict__ curr_c = c_ptr;
scalar_t* __restrict__ curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* __restrict__ curr_c_iter = curr_c;
scalar_t* __restrict__ curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_size / 16>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
scalar_vec_t c_vec(c_vec_fp32);
c_vec.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
scalar_t* __restrict__ d_ptr,
scalar_t* __restrict__ bias_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
constexpr int32_t n_group_num = n_size / 16;
static_assert(n_group_num <= 16);
vec_op::FP32Vec16 bias_vecs[n_group_num];
scalar_t* __restrict__ curr_bias = bias_ptr;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
scalar_vec_t vec(curr_bias);
bias_vecs[i] = vec_op::FP32Vec16(vec);
curr_bias += 16;
});
float* __restrict__ curr_c = c_ptr;
scalar_t* __restrict__ curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* __restrict__ curr_c_iter = curr_c;
scalar_t* __restrict__ curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
scalar_vec_t c_vec(c_vec_fp32);
c_vec.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
} // namespace cpu_micro_gemm
#endif

View File

@@ -0,0 +1,115 @@
#ifndef CPU_MICRO_GEMM_VEC_HPP
#define CPU_MICRO_GEMM_VEC_HPP
#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp"
namespace cpu_micro_gemm {
namespace {
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
template <typename scalar_t>
class TileGemm82 {
public:
FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
switch (m) {
case 1:
gemm_micro<1>(CPU_MICRO_GEMM_PARAMS);
break;
case 2:
gemm_micro<2>(CPU_MICRO_GEMM_PARAMS);
break;
case 3:
gemm_micro<3>(CPU_MICRO_GEMM_PARAMS);
break;
case 4:
gemm_micro<4>(CPU_MICRO_GEMM_PARAMS);
break;
case 5:
gemm_micro<5>(CPU_MICRO_GEMM_PARAMS);
break;
case 6:
gemm_micro<6>(CPU_MICRO_GEMM_PARAMS);
break;
case 7:
gemm_micro<7>(CPU_MICRO_GEMM_PARAMS);
break;
case 8:
gemm_micro<8>(CPU_MICRO_GEMM_PARAMS);
break;
}
}
template <int32_t M>
static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) {
static_assert(0 < M <= 8);
using load_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
scalar_t* __restrict__ curr_b_0 = b_ptr;
scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride;
float* __restrict__ curr_c_0 = c_ptr;
float* __restrict__ curr_c_1 = c_ptr + 16;
vec_op::FP32Vec16 c_regs[M * 2];
if (accum_c) {
float* __restrict__ curr_m_c_0 = curr_c_0;
float* __restrict__ curr_m_c_1 = curr_c_1;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
// update
curr_m_c_0 += ldc;
curr_m_c_1 += ldc;
});
}
scalar_t* __restrict__ curr_a = a_ptr;
for (int32_t k_idx = 0; k_idx < k; ++k_idx) {
load_vec_t b_0_reg(curr_b_0);
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
load_vec_t b_1_reg(curr_b_1);
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
scalar_t* __restrict__ curr_m_a = curr_a;
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
scalar_t v = *curr_m_a;
load_vec_t a_reg_original(v);
vec_op::FP32Vec16 a_reg(a_reg_original);
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
// update
curr_m_a += lda;
});
// update
curr_a += 1;
curr_b_0 += 16;
curr_b_1 += 16;
}
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
c_regs[i * 2].save(curr_c_0);
c_regs[i * 2 + 1].save(curr_c_1);
// update
curr_c_0 += ldc;
curr_c_1 += ldc;
});
}
};
} // namespace
// Gemm kernel uses vector instructions, requires B matrix to be packed
template <typename scalar_t>
class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
public:
static constexpr int32_t MaxMSize = 8;
static constexpr int32_t NSize = 32;
public:
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
};
} // namespace cpu_micro_gemm
#endif

View File

@@ -103,6 +103,13 @@ void cpu_attention_with_kv_cache(
// Note: just for avoiding importing errors
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
torch::Tensor& output, const torch::Tensor& scales,
const std::optional<torch::Tensor>& zeros,
const std::optional<torch::Tensor>& g_idx,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@@ -283,6 +290,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);
// WNA16
#if defined(__AVX512F__)
ops.def(
"cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
"Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {

55
csrc/cpu/utils.hpp Normal file
View File

@@ -0,0 +1,55 @@
#ifndef UTILS_HPP
#define UTILS_HPP
#include <atomic>
#include <cassert>
#include <cstdint>
#include <unistd.h>
#include "cpu_types.hpp"
namespace cpu_utils {
enum class ISA { AMX, VEC };
template <typename T>
struct VecTypeTrait {
using vec_t = void;
};
template <>
struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
};
struct Counter {
std::atomic<int64_t> counter;
char _padding[56];
Counter() : counter(0) {}
void reset_counter() { counter.store(0); }
int64_t acquire_counter() { return counter++; }
};
inline int64_t get_l2_size() {
static int64_t size = []() {
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
assert(l2_cache_size != -1);
return l2_cache_size >> 1; // use 50% of L2 cache
}();
return size;
}
} // namespace cpu_utils
#endif