diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7479c4397..2267718f7 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -73,12 +73,11 @@ function cpu_tests() { pytest -x -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs" - # Note: disable it until supports V1 - # Run AWQ test - # docker exec cpu-test-"$NUMA_NODE" bash -c " - # set -e - # pytest -x -s -v \ - # tests/quantization/test_ipex_quant.py" + # Run AWQ/GPTQ test + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -x -s -v \ + tests/quantization/test_cpu_wna16.py" # Run multi-lora tests docker exec cpu-test-"$NUMA_NODE" bash -c " diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index aa8412581..fbbb03c5e 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -375,6 +375,7 @@ set(VLLM_EXT_SRC if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC "csrc/cpu/shm.cpp" + "csrc/cpu/cpu_wna16.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) set(VLLM_EXT_SRC diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 344296528..294b4f714 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -1,7 +1,6 @@ #ifndef CPU_ATTN_HPP #define CPU_ATTN_HPP -#include #include #include @@ -12,6 +11,7 @@ #include "cpu_types.hpp" #include "scratchpad_manager.h" #include "cpu_attn_macros.h" +#include "utils.hpp" namespace cpu_attention { enum class ISA { AMX, VEC, VEC16 }; diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 7ddf028e6..6f51277f7 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -104,6 +104,8 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(bool, void* ptr) : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} + explicit FP16Vec16(const c10::Half v) : reg(_mm256_set1_epi16(v.x)) {} + explicit FP16Vec16(const FP32Vec16&); void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } @@ -141,6 +143,8 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(bool, void* ptr) : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} + explicit BF16Vec16(const c10::BFloat16 v) : reg(_mm256_set1_epi16(v.x)) {} + explicit BF16Vec16(const FP32Vec16&); void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } @@ -350,6 +354,22 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(__m512 data) : reg(data) {} + // de-pack 4 bit values + explicit FP32Vec16(int64_t value, const FP32Vec16& lut) { + int64_t mask_0 = 0x0F0F0F0F0F0F0F0F; + int64_t mask_1 = 0xF0F0F0F0F0F0F0F0; + int64_t value_0 = value & mask_0; + int64_t value_1 = value & mask_1; + __m128i vec_0 = _mm_movpi64_epi64((__m64)value_0); + __m128i vec_1 = _mm_movpi64_epi64((__m64)value_1); + vec_0 = _mm_cvtepu8_epi16(vec_0); + vec_1 = _mm_cvtepu8_epi16(vec_1); + vec_1 = _mm_slli_epi16(vec_1, 4); + __m128i vec = _mm_or_si128(vec_0, vec_1); + __m512i vec_i32 = _mm512_cvtepu8_epi32(vec); + reg = _mm512_permutexvar_ps(vec_i32, lut.reg); + } + explicit FP32Vec16(const FP32Vec4& data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( @@ -426,14 +446,6 @@ struct FP32Vec16 : public Vec { float get_last_elem() const { return _mm512_cvtss_f32(reg); } - template - float reduce_sub_sum(int idx) { - static_assert(VEC_ELEM_NUM % group_size == 0); - constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); - __mmask16 mask = _cvtu32_mask16(base_mask << (idx * 
group_size)); - return _mm512_mask_reduce_add_ps(mask, reg); - } - void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); } void save(float* ptr, const int elem_num) const { @@ -755,6 +767,25 @@ inline void non_temporal_save(BF16Vec16& vec, void* ptr) { inline void non_temporal_save(FP32Vec16& vec, void* ptr) { _mm512_stream_ps((float*)ptr, vec.reg); } + +static void interleave_save(const BF16Vec16& vec0, const BF16Vec16& vec1, + void* ptr) { + __m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg); + __m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg); + vec_1 = _mm512_slli_epi32(vec_1, 16); + vec_0 = _mm512_or_si512(vec_0, vec_1); + _mm512_storeu_epi32(ptr, vec_0); +} + +static void interleave_save(const FP16Vec16& vec0, const FP16Vec16& vec1, + void* ptr) { + __m512i vec_0 = _mm512_cvtepu16_epi32(vec0.reg); + __m512i vec_1 = _mm512_cvtepu16_epi32(vec1.reg); + vec_1 = _mm512_slli_epi32(vec_1, 16); + vec_0 = _mm512_or_si512(vec_0, vec_1); + _mm512_storeu_epi32(ptr, vec_0); +} + #endif inline void mem_barrier() { _mm_mfence(); } diff --git a/csrc/cpu/cpu_wna16.cpp b/csrc/cpu/cpu_wna16.cpp new file mode 100644 index 000000000..816d19550 --- /dev/null +++ b/csrc/cpu/cpu_wna16.cpp @@ -0,0 +1,402 @@ +#include "cpu_types.hpp" +#include "scratchpad_manager.h" +#include "utils.hpp" + +#ifdef CPU_CAPABILITY_AMXBF16 + #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp" +#endif +#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp" + +#define VLLM_DISPATCH_CASE_16B_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define VLLM_DISPATCH_16B_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_16B_TYPES(__VA_ARGS__)) + +template +void print_logits(const char* name, T* ptr, int32_t row, int32_t col, + int32_t stride) { + std::stringstream ss; + ss << std::fixed << std::setprecision(5) << name << ": [\n"; + auto* curr_logits_buffer = ptr; + for (int32_t m = 0; m < row; ++m) { + for (int32_t n = 0; n < col; ++n) { + ss << curr_logits_buffer[n] << ", "; + } + ss << "\n"; + curr_logits_buffer += stride; + } + ss << "]\n"; + std::printf("%s", ss.str().c_str()); +} + +namespace { +using cpu_utils::ISA; +using cpu_utils::VecTypeTrait; + +template +class Dequantizer4b { + public: + constexpr static int32_t pack_num = 32 / 4; + using scalar_vec_t = typename VecTypeTrait::vec_t; + + public: + static void dequant(int32_t* __restrict__ q_weight, + scalar_t* __restrict__ weight, + scalar_t* __restrict__ scales, + int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx, + const int64_t scales_stride, const int64_t zeros_stride, + const int32_t k_size, const int32_t group_size) { + vec_op::FP32Vec16 lut; + if constexpr (has_zp) { + // AWQ + alignas(64) static const float LUT[16] = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f}; + lut = vec_op::FP32Vec16(LUT); + } else { + // GPTQ + alignas(64) static const float LUT[16] = { + -8.0f, -7.0f, -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f, + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + lut = vec_op::FP32Vec16(LUT); + } + + // per 64-bits elem contains 16 output channels + int64_t* __restrict__ curr_q_weight = reinterpret_cast(q_weight); + int64_t* __restrict__ curr_zeros = reinterpret_cast(zeros); + scalar_t* __restrict__ curr_weight = weight; + scalar_t* __restrict__ curr_scale = scales; + vec_op::FP32Vec16 scale_0; + vec_op::FP32Vec16 scale_1; + vec_op::FP32Vec16 zero_0; + vec_op::FP32Vec16 zero_1; + int32_t 
group_counter = 0; + for (int32_t k_idx = 0; k_idx < k_size; k_idx += 2) { + int64_t qwb_0 = *curr_q_weight; + int64_t qwb_1 = *(curr_q_weight + 1); + vec_op::FP32Vec16 wb_0(qwb_0, lut); + vec_op::FP32Vec16 wb_1(qwb_1, lut); + + if constexpr (!use_desc_act) { + if (group_counter == 0) { + scale_0 = vec_op::FP32Vec16(scalar_vec_t(curr_scale)); + scale_1 = vec_op::FP32Vec16(scale_0); + curr_scale += scales_stride; + + if constexpr (has_zp) { + zero_0 = vec_op::FP32Vec16(*curr_zeros, lut); + zero_1 = vec_op::FP32Vec16(zero_0); + curr_zeros += zeros_stride / 2; + } + } + } else { + int32_t g_idx_0 = g_idx[k_idx]; + int32_t g_idx_1 = g_idx[k_idx + 1]; + scale_0 = vec_op::FP32Vec16( + scalar_vec_t(curr_scale + g_idx_0 * scales_stride)); + scale_1 = vec_op::FP32Vec16( + scalar_vec_t(curr_scale + g_idx_1 * scales_stride)); + if constexpr (has_zp) { + zero_0 = vec_op::FP32Vec16(*(curr_zeros + g_idx_0 * zeros_stride / 2), + lut); + zero_1 = vec_op::FP32Vec16(*(curr_zeros + g_idx_1 * zeros_stride / 2), + lut); + } + } + + if constexpr (has_zp) { + wb_0 = wb_0 - zero_0; + wb_1 = wb_1 - zero_1; + } + + wb_0 = wb_0 * scale_0; + wb_1 = wb_1 * scale_1; + + scalar_vec_t output_vec_0(wb_0); + scalar_vec_t output_vec_1(wb_1); + + // AMX needs to interlave K elements to pack as 32 bits + if constexpr (isa == ISA::AMX) { + vec_op::interleave_save(output_vec_0, output_vec_1, curr_weight); + } else { + output_vec_0.save(curr_weight); + output_vec_1.save(curr_weight + 16); + } + + // update + curr_q_weight += 2; + curr_weight += 32; + if constexpr (!use_desc_act) { + group_counter += 2; + if (group_counter == group_size) { + group_counter = 0; + } + } + } + } +}; +}; // namespace + +template +void cpu_gemm_wna16_impl( + scalar_t* __restrict__ input, int32_t* __restrict__ q_weight, + scalar_t* __restrict__ output, scalar_t* __restrict__ scales, + int32_t* __restrict__ zeros, int32_t* __restrict__ g_idx, + scalar_t* __restrict__ bias, const int32_t m_size, const int32_t n_size, + const int32_t k_size, const int64_t input_stride, + const int64_t output_stride, const int64_t scales_group_stride, + const int64_t zeros_group_stride, const int32_t group_num, + const int32_t group_size, const int64_t pack_factor) { + constexpr int32_t gemm_n_tile_size = gemm_t::NSize; + constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize; + constexpr int32_t n_block_size = 16; + static_assert(gemm_n_tile_size % n_block_size == 0); + const int32_t thread_num = omp_get_max_threads(); + + // a simple schedule policy, just to hold more B tiles in L2 and make sure + // each thread has tasks + const int32_t n_partition_size = [&]() { + const int64_t cache_size = cpu_utils::get_l2_size(); + int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t)); + int64_t ps_thread_limit = n_size / thread_num; + ps_cache_limit = + std::max((ps_cache_limit / gemm_n_tile_size) * gemm_n_tile_size, + (int64_t)gemm_n_tile_size); + ps_thread_limit = + std::max((ps_thread_limit / gemm_n_tile_size) * gemm_n_tile_size, + (int64_t)gemm_n_tile_size); + return std::min(ps_cache_limit, ps_thread_limit); + }(); + const int32_t task_num = (n_size + n_partition_size - 1) / n_partition_size; + + // get buffer size + const int64_t b_buffer_size = + (((n_partition_size * k_size * sizeof(scalar_t) + 63) / 64) * 64); + const int64_t c_buffer_size = + (((gemm_m_tile_size * gemm_n_tile_size * sizeof(float) + 63) / 64) * 64); + const int64_t b_buffer_offset = 0; + const int64_t c_buffer_offset = b_buffer_size; + const int64_t buffer_size = b_buffer_size + c_buffer_size; + 
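The partition size computed just above is a simple heuristic: keep one dequantized slice of B (n_partition_size x k_size in 16-bit) within the usable share of L2 reported by cpu_utils::get_l2_size(), while still producing at least one task per OpenMP thread. A rough Python restatement of the same arithmetic (function and variable names here are illustrative, not part of the patch):

```python
# Illustrative re-statement of the n_partition_size heuristic in
# cpu_gemm_wna16_impl: bound the slice by L2 capacity and by thread count,
# round down to the GEMM N-tile (32 columns), and never go below one tile.
def pick_n_partition_size(n_size: int, k_size: int, thread_num: int,
                          l2_bytes: int, n_tile: int = 32, elem_bytes: int = 2) -> int:
    ps_cache_limit = l2_bytes // (k_size * elem_bytes)
    ps_thread_limit = n_size // thread_num
    ps_cache_limit = max(ps_cache_limit // n_tile * n_tile, n_tile)
    ps_thread_limit = max(ps_thread_limit // n_tile * n_tile, n_tile)
    return min(ps_cache_limit, ps_thread_limit)

# e.g. N=4096, K=4096, 32 threads, ~1 MiB of usable L2 -> 128 columns per task
print(pick_n_partition_size(4096, 4096, 32, 1 << 20))
```

Each task then dequantizes its slice of B once into the per-thread scratchpad and reuses it across all M tiles.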
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size * + thread_num); + + alignas(64) cpu_utils::Counter counter; + cpu_utils::Counter* counter_ptr = &counter; + +#pragma omp parallel for schedule(static, 1) + for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) { + scalar_t* __restrict__ b_buffer = nullptr; + float* __restrict__ c_buffer = nullptr; + { + uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager() + ->get_data() + + thread_id * buffer_size; + b_buffer = reinterpret_cast(buffer_ptr + b_buffer_offset); + c_buffer = reinterpret_cast(buffer_ptr + c_buffer_offset); + } + + const int64_t q_weight_block_stride = n_block_size / pack_factor * k_size; + const int64_t b_buffer_block_stride = n_block_size * k_size; + const int32_t zeros_block_stride = n_block_size / pack_factor; + + gemm_t gemm; + + for (;;) { + int32_t task_id = counter_ptr->acquire_counter(); + + if (task_id >= task_num) { + break; + } + + const int32_t n_start_idx = task_id * n_partition_size; + const int32_t n_block_start_idx = n_start_idx / n_block_size; + const int32_t n_num = std::min(n_partition_size, n_size - n_start_idx); + const int32_t n_block_num = n_num / n_block_size; + // std::printf("thread_id: %d, task_id: %d, n_start_idx: %d, n_num: %d\n", + // thread_id, task_id, n_start_idx, n_num); + + // dequant weight + { + int32_t* __restrict__ curr_q_weight = + q_weight + n_block_start_idx * q_weight_block_stride; + scalar_t* __restrict__ curr_b_buffer = b_buffer; + scalar_t* __restrict__ curr_scales = scales + n_start_idx; + int32_t* __restrict__ curr_zeros = zeros + n_start_idx / pack_factor; + for (int32_t block_idx = 0; block_idx < n_block_num; ++block_idx) { + dequantizer_t::dequant(curr_q_weight, curr_b_buffer, curr_scales, + curr_zeros, g_idx, scales_group_stride, + zeros_group_stride, k_size, group_size); + + // if (block_idx == 0 && n_start_idx == 0) { + // print_logits("depacked weight", curr_b_buffer, k_size, + // n_block_size, n_block_size); + // } + + // update + curr_q_weight += q_weight_block_stride; + curr_b_buffer += b_buffer_block_stride; + curr_scales += n_block_size; + curr_zeros += zeros_block_stride; + } + } + + // compute loop + { + const int32_t n_tile_num = n_num / gemm_n_tile_size; + scalar_t* __restrict__ curr_input = input; + scalar_t* __restrict__ init_bias = bias; + if (bias != nullptr) { + init_bias += n_start_idx; + } + scalar_t* __restrict__ init_output = output + n_start_idx; + for (int32_t m_idx = 0; m_idx < m_size; m_idx += gemm_m_tile_size) { + const int32_t curr_m_size = + std::min(gemm_m_tile_size, m_size - m_idx); + scalar_t* __restrict__ curr_b_buffer = b_buffer; + scalar_t* __restrict__ curr_bias = init_bias; + scalar_t* __restrict__ curr_output = init_output; + for (int32_t n_tile_idx = 0; n_tile_idx < n_tile_num; ++n_tile_idx) { + gemm.gemm(curr_input, curr_b_buffer, c_buffer, curr_m_size, k_size, + input_stride, b_buffer_block_stride, gemm_n_tile_size, + false); + + if (bias != nullptr) { + cpu_micro_gemm::bias_epilogue( + c_buffer, curr_output, curr_bias, curr_m_size, + gemm_n_tile_size, output_stride); + curr_bias += gemm_n_tile_size; + } else { + cpu_micro_gemm::default_epilogue( + c_buffer, curr_output, curr_m_size, gemm_n_tile_size, + output_stride); + } + + curr_b_buffer += + b_buffer_block_stride * (gemm_n_tile_size / n_block_size); + curr_output += gemm_n_tile_size; + } + curr_input += gemm_m_tile_size * input_stride; + init_output += gemm_m_tile_size * output_stride; + } + } + } + } +} + +void cpu_gemm_wna16( 
+ const torch::Tensor& input, // [M, K] + const torch::Tensor& + q_weight, // [N / 16, K * 16 / pack_factor], packed as int32 + torch::Tensor& output, // [M, N] + const torch::Tensor& scales, // [group_num, N] + const std::optional& + zeros, // [group_num, N / pack_factor], packed as int32 + const std::optional& g_idx, // [K] + const std::optional& bias, // [N] + const int64_t pack_factor, const std::string& isa_hint) { + using cpu_utils::ISA; + TORCH_CHECK_EQ(pack_factor, 8); // only supports 4bits + const int32_t a_m_size = input.size(0); + const int32_t a_k_size = input.size(1); + const int64_t a_m_stride = input.stride(0); + const int32_t b_n_size = q_weight.size(0) * 16; + TORCH_CHECK_EQ(a_k_size % 32, 0); + TORCH_CHECK_EQ(b_n_size % 32, 0); + const int32_t group_num = scales.size(0); + const int32_t group_size = a_k_size / group_num; + TORCH_CHECK_EQ(group_size % 2, 0); + const int64_t scales_group_stride = scales.stride(0); + const int64_t output_m_stride = output.stride(0); + + bool has_zp = zeros.has_value(); + bool use_desc_act = g_idx.has_value(); + TORCH_CHECK(!(has_zp && use_desc_act)); + + ISA isa = [&]() { + if (isa_hint == "amx") { + return ISA::AMX; + } else if (isa_hint == "vec") { + return ISA::VEC; + } else { + TORCH_CHECK(false, "unsupported isa hint: " + isa_hint); + } + }(); + + int32_t* zeros_ptr = has_zp ? zeros->data_ptr() : nullptr; + const int64_t zeros_group_stride = has_zp ? zeros->stride(0) : 0; + int32_t* g_idx_ptr = use_desc_act ? g_idx->data_ptr() : nullptr; + + VLLM_DISPATCH_16B_TYPES(input.scalar_type(), "cpu_gemm_wna16", [&]() { + if (isa == ISA::AMX) { + using gemm_t = cpu_micro_gemm::MicroGemm; + if (has_zp) { + using dequantizer_t = Dequantizer4b; + cpu_gemm_wna16_impl( + input.data_ptr(), q_weight.data_ptr(), + output.data_ptr(), scales.data_ptr(), zeros_ptr, + g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr, + a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride, + scales_group_stride, zeros_group_stride, group_num, group_size, + pack_factor); + return; + } + if (use_desc_act) { + using dequantizer_t = Dequantizer4b; + cpu_gemm_wna16_impl( + input.data_ptr(), q_weight.data_ptr(), + output.data_ptr(), scales.data_ptr(), zeros_ptr, + g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr, + a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride, + scales_group_stride, zeros_group_stride, group_num, group_size, + pack_factor); + return; + } else { + using dequantizer_t = Dequantizer4b; + cpu_gemm_wna16_impl( + input.data_ptr(), q_weight.data_ptr(), + output.data_ptr(), scales.data_ptr(), zeros_ptr, + g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr, + a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride, + scales_group_stride, zeros_group_stride, group_num, group_size, + pack_factor); + return; + } + } else if (isa == ISA::VEC) { + using gemm_t = cpu_micro_gemm::MicroGemm; + if (has_zp) { + using dequantizer_t = Dequantizer4b; + cpu_gemm_wna16_impl( + input.data_ptr(), q_weight.data_ptr(), + output.data_ptr(), scales.data_ptr(), zeros_ptr, + g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr, + a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride, + scales_group_stride, zeros_group_stride, group_num, group_size, + pack_factor); + return; + } + if (use_desc_act) { + using dequantizer_t = Dequantizer4b; + cpu_gemm_wna16_impl( + input.data_ptr(), q_weight.data_ptr(), + output.data_ptr(), scales.data_ptr(), zeros_ptr, + g_idx_ptr, bias.has_value() ? 
bias->data_ptr() : nullptr, + a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride, + scales_group_stride, zeros_group_stride, group_num, group_size, + pack_factor); + return; + } else { + using dequantizer_t = Dequantizer4b; + cpu_gemm_wna16_impl( + input.data_ptr(), q_weight.data_ptr(), + output.data_ptr(), scales.data_ptr(), zeros_ptr, + g_idx_ptr, bias.has_value() ? bias->data_ptr() : nullptr, + a_m_size, b_n_size, a_k_size, a_m_stride, output_m_stride, + scales_group_stride, zeros_group_stride, group_num, group_size, + pack_factor); + return; + } + } + }); +} diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index 02a8072cc..cfb6e78cb 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -396,9 +396,9 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args) : DNNLMatMulPrimitiveHandler( static_cast(args), args.ab_type), m_size_cache_(nullptr) { - assert(ab_type_ == dnnl::memory::data_type::f32 || - ab_type_ == dnnl::memory::data_type::bf16 || - ab_type_ == dnnl::memory::data_type::f16); + assert(b_type_ == dnnl::memory::data_type::f32 || + b_type_ == dnnl::memory::data_type::bf16 || + b_type_ == dnnl::memory::data_type::f16); dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, {b_k_stride_, b_n_stride_}); diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp new file mode 100644 index 000000000..87a019773 --- /dev/null +++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp @@ -0,0 +1,245 @@ +#ifndef CPU_MICRO_GEMM_AMX_HPP +#define CPU_MICRO_GEMM_AMX_HPP +#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp" + +namespace cpu_micro_gemm { +namespace { +// AMX specific +constexpr static int64_t AMX_TILE_ROW_BYTES = 64; +constexpr static int64_t AMX_TILE_ROW_NUM = 16; +constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM; + +typedef struct __tile_config { + uint8_t palette_id = 1; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +} __tilecfg; + +// 2-2-4 pattern, for 16 < m <= 32 +// TILE 0, 1: load A matrix, row num should be 16, m - 16 +// TILE 2, 3: load B matrix, row num should be 16 +// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m +// - 16 +template +class TileGemm224 { + public: + FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + TORCH_CHECK(false, "Unsupported data type for TileGemm224"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported data type for TileGemm224"); + } +}; + +template <> +class TileGemm224 { + public: + using scalar_t = c10::BFloat16; + FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + c10::BFloat16* __restrict__ a_tile_0 = a_ptr; + c10::BFloat16* __restrict__ a_tile_1 = a_ptr + lda * AMX_TILE_ROW_NUM; + const int64_t a_tile_stride = lda * sizeof(c10::BFloat16); + + // B is always packed as 16 output channels block + c10::BFloat16* __restrict__ b_tile_2 = b_ptr; + c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride; + const int32_t b_tile_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_4 = c_ptr; + float* __restrict__ c_tile_5 = + c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float); + float* __restrict__ c_tile_6 = c_ptr + AMX_TILE_ROW_NUM * ldc; + float* __restrict__ c_tile_7 = + c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float); + const int32_t c_tile_stride = 
ldc * sizeof(float); + + if (accum_c) { + _tile_loadd(4, c_tile_4, c_tile_stride); + _tile_loadd(5, c_tile_5, c_tile_stride); + _tile_loadd(6, c_tile_6, c_tile_stride); + _tile_loadd(7, c_tile_7, c_tile_stride); + } else { + _tile_zero(4); + _tile_zero(5); + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_tile_stride); + _tile_dpbf16ps(4, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_tile_stride); + _tile_dpbf16ps(5, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + _tile_dpbf16ps(6, 1, 2); + _tile_dpbf16ps(7, 1, 3); + + // update ptrs + a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + _tile_stored(4, c_tile_4, c_tile_stride); + _tile_stored(5, c_tile_5, c_tile_stride); + _tile_stored(6, c_tile_6, c_tile_stride); + _tile_stored(7, c_tile_7, c_tile_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + const int32_t m_0 = AMX_TILE_ROW_NUM; + const int32_t m_1 = m - AMX_TILE_ROW_NUM; + config.rows[0] = m_0; + config.rows[1] = m_1; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = m_0; + config.rows[5] = m_0; + config.rows[6] = m_1; + config.rows[7] = m_1; + _tile_loadconfig(&config); + } +}; + +// 1-2-2 pattern, for 0 < m <= 16 +// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be +// m, m +// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row +// num should be 16 +// TILE 6, 7, (6, 7): store results C matrix, row num should be +// m +template +class TileGemm122 { + public: + FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + TORCH_CHECK(false, "Unsupported data type for TileGemm122"); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + TORCH_CHECK(false, "Unsupported data type for TileGemm122"); + } +}; + +template <> +class TileGemm122 { + public: + using scalar_t = c10::BFloat16; + FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + c10::BFloat16* __restrict__ a_tile_0 = a_ptr; + c10::BFloat16* __restrict__ a_tile_1 = + a_ptr + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + const int64_t a_tile_stride = lda * sizeof(c10::BFloat16); + + c10::BFloat16* __restrict__ b_tile_2 = b_ptr; + c10::BFloat16* __restrict__ b_tile_3 = b_ptr + b_n_group_stride; + c10::BFloat16* __restrict__ b_tile_4 = + b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + c10::BFloat16* __restrict__ b_tile_5 = + b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16); + int64_t b_stride = AMX_TILE_ROW_BYTES; + + float* __restrict__ c_tile_6 = c_ptr; + float* __restrict__ c_tile_7 = c_ptr + AMX_TILE_ROW_BYTES / sizeof(float); + int64_t c_stride = ldc * sizeof(float); + + const int32_t k_times = k / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16)); + const int32_t k_group_times = k_times / 2; + const bool has_tail = (k_times % 2 == 1); + + if (accum_c) { + _tile_loadd(6, c_tile_6, c_stride); + _tile_loadd(7, c_tile_7, c_stride); + } else { + _tile_zero(6); + _tile_zero(7); + } + + for (int32_t k = 0; k < k_group_times; ++k) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + _tile_loadd(1, a_tile_1, a_tile_stride); + 
_tile_stream_loadd(4, b_tile_4, b_stride); + _tile_dpbf16ps(6, 1, 4); + _tile_stream_loadd(5, b_tile_5, b_stride); + _tile_dpbf16ps(7, 1, 5); + + // update ptrs + a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16); + b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16); + } + + if (has_tail) { + _tile_loadd(0, a_tile_0, a_tile_stride); + _tile_stream_loadd(2, b_tile_2, b_stride); + _tile_dpbf16ps(6, 0, 2); + _tile_stream_loadd(3, b_tile_3, b_stride); + _tile_dpbf16ps(7, 0, 3); + } + + _tile_stored(6, c_tile_6, c_stride); + _tile_stored(7, c_tile_7, c_stride); + } + + FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) { + config.rows[0] = m; + config.rows[1] = m; + config.rows[2] = AMX_TILE_ROW_NUM; + config.rows[3] = AMX_TILE_ROW_NUM; + config.rows[4] = AMX_TILE_ROW_NUM; + config.rows[5] = AMX_TILE_ROW_NUM; + config.rows[6] = m; + config.rows[7] = m; + _tile_loadconfig(&config); + } +}; +} // namespace + +// Gemm kernel uses AMX, requires B matrix to be packed +template +class MicroGemm { + public: + static constexpr int32_t MaxMSize = 32; + static constexpr int32_t NSize = 32; + + public: + MicroGemm() : curr_m_(-1) { + vec_op::unroll_loop([&](int i) { amx_tile_config_.colsb[i] = 64; }); + } + + void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + if (m > AMX_TILE_ROW_NUM) { + if (m != curr_m_) { + curr_m_ = m; + TileGemm224::init_tile_config(m, amx_tile_config_); + } + TileGemm224::gemm(CPU_MICRO_GEMM_PARAMS); + } else { + if (m != curr_m_) { + curr_m_ = m; + TileGemm122::init_tile_config(m, amx_tile_config_); + } + TileGemm122::gemm(CPU_MICRO_GEMM_PARAMS); + } + } + + private: + alignas(64) __tilecfg amx_tile_config_; + int32_t curr_m_; +}; + +} // namespace cpu_micro_gemm + +#endif diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp new file mode 100644 index 000000000..784da55a4 --- /dev/null +++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp @@ -0,0 +1,91 @@ +#ifndef CPU_MICRO_GEMM_IMPL_HPP +#define CPU_MICRO_GEMM_IMPL_HPP +#include "cpu/utils.hpp" +#include "cpu/cpu_types.hpp" + +namespace cpu_micro_gemm { +#define DEFINE_CPU_MICRO_GEMM_PARAMS \ + scalar_t *__restrict__ a_ptr, scalar_t *__restrict__ b_ptr, \ + float *__restrict__ c_ptr, const int32_t m, const int32_t k, \ + const int64_t lda, const int64_t b_n_group_stride, const int64_t ldc, \ + const bool accum_c + +#define CPU_MICRO_GEMM_PARAMS \ + a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c + +template +class MicroGemm { + public: + static constexpr int32_t MaxMSize = 16; + static constexpr int32_t NSize = 16; + + public: + void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + TORCH_CHECK(false, "Unimplemented MicroGemm."); + } +}; + +template +FORCE_INLINE void default_epilogue(float* __restrict__ c_ptr, + scalar_t* __restrict__ d_ptr, + const int32_t m, const int64_t ldc, + const int64_t ldd) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + static_assert(n_size % 16 == 0); + + float* __restrict__ curr_c = c_ptr; + scalar_t* __restrict__ curr_d = d_ptr; + for (int32_t i = 0; i < m; ++i) { + float* __restrict__ curr_c_iter = curr_c; + scalar_t* __restrict__ curr_d_iter = curr_d; + vec_op::unroll_loop([&](int32_t n_g_idx) { + vec_op::FP32Vec16 c_vec_fp32(curr_c_iter); + scalar_vec_t 
c_vec(c_vec_fp32); + c_vec.save(curr_d_iter); + curr_c_iter += 16; + curr_d_iter += 16; + }); + curr_c += ldc; + curr_d += ldd; + } +} + +template +FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr, + scalar_t* __restrict__ d_ptr, + scalar_t* __restrict__ bias_ptr, + const int32_t m, const int64_t ldc, + const int64_t ldd) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + static_assert(n_size % 16 == 0); + constexpr int32_t n_group_num = n_size / 16; + static_assert(n_group_num <= 16); + + vec_op::FP32Vec16 bias_vecs[n_group_num]; + scalar_t* __restrict__ curr_bias = bias_ptr; + vec_op::unroll_loop([&](int32_t i) { + scalar_vec_t vec(curr_bias); + bias_vecs[i] = vec_op::FP32Vec16(vec); + curr_bias += 16; + }); + + float* __restrict__ curr_c = c_ptr; + scalar_t* __restrict__ curr_d = d_ptr; + for (int32_t i = 0; i < m; ++i) { + float* __restrict__ curr_c_iter = curr_c; + scalar_t* __restrict__ curr_d_iter = curr_d; + vec_op::unroll_loop([&](int32_t n_g_idx) { + vec_op::FP32Vec16 c_vec_fp32(curr_c_iter); + c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx]; + scalar_vec_t c_vec(c_vec_fp32); + c_vec.save(curr_d_iter); + curr_c_iter += 16; + curr_d_iter += 16; + }); + curr_c += ldc; + curr_d += ldd; + } +} +} // namespace cpu_micro_gemm + +#endif diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp new file mode 100644 index 000000000..3985c2f2e --- /dev/null +++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp @@ -0,0 +1,115 @@ +#ifndef CPU_MICRO_GEMM_VEC_HPP +#define CPU_MICRO_GEMM_VEC_HPP +#include "cpu/micro_gemm/cpu_micro_gemm_impl.hpp" + +namespace cpu_micro_gemm { +namespace { +// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32] +template +class TileGemm82 { + public: + FORCE_INLINE static void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + switch (m) { + case 1: + gemm_micro<1>(CPU_MICRO_GEMM_PARAMS); + break; + case 2: + gemm_micro<2>(CPU_MICRO_GEMM_PARAMS); + break; + case 3: + gemm_micro<3>(CPU_MICRO_GEMM_PARAMS); + break; + case 4: + gemm_micro<4>(CPU_MICRO_GEMM_PARAMS); + break; + case 5: + gemm_micro<5>(CPU_MICRO_GEMM_PARAMS); + break; + case 6: + gemm_micro<6>(CPU_MICRO_GEMM_PARAMS); + break; + case 7: + gemm_micro<7>(CPU_MICRO_GEMM_PARAMS); + break; + case 8: + gemm_micro<8>(CPU_MICRO_GEMM_PARAMS); + break; + } + } + + template + static void gemm_micro(DEFINE_CPU_MICRO_GEMM_PARAMS) { + static_assert(0 < M <= 8); + using load_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + + scalar_t* __restrict__ curr_b_0 = b_ptr; + scalar_t* __restrict__ curr_b_1 = b_ptr + b_n_group_stride; + float* __restrict__ curr_c_0 = c_ptr; + float* __restrict__ curr_c_1 = c_ptr + 16; + + vec_op::FP32Vec16 c_regs[M * 2]; + if (accum_c) { + float* __restrict__ curr_m_c_0 = curr_c_0; + float* __restrict__ curr_m_c_1 = curr_c_1; + vec_op::unroll_loop([&](int32_t i) { + c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0); + c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1); + + // update + curr_m_c_0 += ldc; + curr_m_c_1 += ldc; + }); + } + + scalar_t* __restrict__ curr_a = a_ptr; + for (int32_t k_idx = 0; k_idx < k; ++k_idx) { + load_vec_t b_0_reg(curr_b_0); + vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg); + load_vec_t b_1_reg(curr_b_1); + vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg); + + scalar_t* __restrict__ curr_m_a = curr_a; + vec_op::unroll_loop([&](int32_t i) { + scalar_t v = *curr_m_a; + load_vec_t a_reg_original(v); + vec_op::FP32Vec16 a_reg(a_reg_original); + c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg; + c_regs[i * 2 + 
1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg; + + // update + curr_m_a += lda; + }); + + // update + curr_a += 1; + curr_b_0 += 16; + curr_b_1 += 16; + } + + vec_op::unroll_loop([&](int32_t i) { + c_regs[i * 2].save(curr_c_0); + c_regs[i * 2 + 1].save(curr_c_1); + + // update + curr_c_0 += ldc; + curr_c_1 += ldc; + }); + } +}; +} // namespace + +// Gemm kernel uses vector instructions, requires B matrix to be packed +template +class MicroGemm { + public: + static constexpr int32_t MaxMSize = 8; + static constexpr int32_t NSize = 32; + + public: + void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { + TileGemm82::gemm(CPU_MICRO_GEMM_PARAMS); + } +}; +} // namespace cpu_micro_gemm + +#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 9fefd88cd..b07d20bab 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -103,6 +103,13 @@ void cpu_attention_with_kv_cache( // Note: just for avoiding importing errors void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); } +void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight, + torch::Tensor& output, const torch::Tensor& scales, + const std::optional& zeros, + const std::optional& g_idx, + const std::optional& bias, + const int64_t pack_factor, const std::string& isa_hint); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -283,6 +290,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("static_scaled_fp8_quant() -> ()", placeholder_op); ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op); ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op); + + // WNA16 +#if defined(__AVX512F__) + ops.def( + "cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, " + "Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt " + "pack_factor, str isa_hint) -> ()"); + ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { diff --git a/csrc/cpu/utils.hpp b/csrc/cpu/utils.hpp new file mode 100644 index 000000000..d8399c56f --- /dev/null +++ b/csrc/cpu/utils.hpp @@ -0,0 +1,55 @@ +#ifndef UTILS_HPP +#define UTILS_HPP + +#include +#include +#include +#include + +#include "cpu_types.hpp" + +namespace cpu_utils { +enum class ISA { AMX, VEC }; + +template +struct VecTypeTrait { + using vec_t = void; +}; + +template <> +struct VecTypeTrait { + using vec_t = vec_op::FP32Vec16; +}; + +template <> +struct VecTypeTrait { + using vec_t = vec_op::BF16Vec16; +}; + +template <> +struct VecTypeTrait { + using vec_t = vec_op::FP16Vec16; +}; + +struct Counter { + std::atomic counter; + char _padding[56]; + + Counter() : counter(0) {} + + void reset_counter() { counter.store(0); } + + int64_t acquire_counter() { return counter++; } +}; + +inline int64_t get_l2_size() { + static int64_t size = []() { + long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE); + assert(l2_cache_size != -1); + return l2_cache_size >> 1; // use 50% of L2 cache + }(); + return size; +} +} // namespace cpu_utils + +#endif diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index be99cef37..d1beab785 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -97,7 +97,6 @@ Currently, there are no pre-built CPU wheels. 
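Stepping back to the new FP32Vec16(int64_t, lut) constructor and Dequantizer4b above: each 64-bit word of q_weight carries the 4-bit values of 16 output channels for one K row, and the nibbles are turned into floats by indexing a 16-entry LUT (0..15 for AWQ, which subtracts a zero point afterwards; -8..7 for GPTQ). A scalar Python reference of that de-pack, for illustration only:

```python
# Scalar reference for the SIMD 4-bit de-pack: nibble i of the 64-bit word is
# output channel i, and the LUT turns it into a float weight. Zero-point
# subtraction (AWQ) and per-group scaling happen on the result, as in
# Dequantizer4b::dequant.
def depack16(word: int, lut: list[float]) -> list[float]:
    return [lut[(word >> (4 * i)) & 0xF] for i in range(16)]

awq_lut = [float(v) for v in range(16)]            # 0.0 .. 15.0, zero point applied later
gptq_lut = [float(v) - 8.0 for v in range(16)]     # symmetric: -8.0 .. 7.0

def dequant16(word: int, scale: float, lut: list[float], zero: float = 0.0) -> list[float]:
    return [(w - zero) * scale for w in depack16(word, lut)]

print(dequant16(0x0000_0000_0000_00F1, 0.5, gptq_lut))  # channels 0 and 1 carry 1 and 15
```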
- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. - `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence. -- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). ## FAQ @@ -191,10 +190,9 @@ vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel - GPTQ (x86 only) - compressed-tensor INT8 W8A8 (x86, s390x) -### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? +### (x86 only) What is the purpose of `VLLM_CPU_SGL_KERNEL`? - Both of them require `amx` CPU flag. - - `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios. ### Why do I see `get_mempolicy: Operation not permitted` when running in Docker? diff --git a/requirements/cpu.txt b/requirements/cpu.txt index d11787df4..e23d3286f 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -22,7 +22,6 @@ datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs intel-openmp==2024.2.1; platform_machine == "x86_64" -intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64" triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile. 
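The thread-binding variables documented above are typically set before the engine starts; a minimal sketch of that usage on a 32-core NUMA node (the model name and values are examples, not recommendations):

```python
# Minimal sketch: pin vLLM's OpenMP threads to cores 0-31 and reserve 40 GiB
# for the CPU KV cache before constructing the engine. Values are illustrative.
import os

os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "0-31"
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
print(llm.generate(["Hello"])[0].outputs[0].text)
```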
# Use this to gather CPU info and optimize based on ARM Neoverse cores diff --git a/tests/quantization/test_cpu_wna16.py b/tests/quantization/test_cpu_wna16.py new file mode 100644 index 000000000..077b802e5 --- /dev/null +++ b/tests/quantization/test_cpu_wna16.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.platforms import current_platform + +if not current_platform.is_cpu(): + pytest.skip("skipping CPU-only tests", allow_module_level=True) + +MODELS = [ + "TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", # with g_idx +] +DTYPE = ["bfloat16"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", DTYPE) +def test_ipex_quant(vllm_runner, model, dtype): + with vllm_runner(model, dtype=dtype) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + assert output + print(output) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 096266c97..66cf6472e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2702,6 +2702,31 @@ def cpu_attention_with_kv_cache( ) +def cpu_gemm_wna16( + input: torch.Tensor, + q_weight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor | None, + g_idx: torch.Tensor | None, + bias: torch.Tensor | None, + pack_factor: int, + isa_hint: str, +) -> torch.Tensor: + output = torch.empty((input.size(0), scales.size(1)), dtype=input.dtype) + torch.ops._C.cpu_gemm_wna16( + input, + q_weight, + output, + scales, + zeros, + g_idx, + bias, + pack_factor, + isa_hint, + ) + return output + + if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): @register_fake("_qutlass_C::matmul_mxf4_bf16_tn") diff --git a/vllm/config/model.py b/vllm/config/model.py index 49fe0bcd9..3e8790a26 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1020,6 +1020,8 @@ class ModelConfig: # Ensure heavy backends are probed last to avoid unnecessary # imports during override detection (e.g., MXFP4 imports Triton) "mxfp4", + "cpu_gptq", + "cpu_awq", ] quantization_methods = [ q for q in supported_quantization if q not in overrides diff --git a/vllm/envs.py b/vllm/envs.py index 62b3344cc..6d92d5afe 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -50,7 +50,6 @@ if TYPE_CHECKING: VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None - VLLM_CPU_MOE_PREPACK: bool = True VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False @@ -665,10 +664,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ) if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None, - # (CPU backend only) whether to use prepack for MoE layer. This will be - # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might - # need to set this to "0" (False). - "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), # (CPU backend only) whether to use SGL kernels, optimized for small batch. 
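For reference, the expected tensor shapes of the new ops.cpu_gemm_wna16 wrapper added above; the sizes below are made up, and running this needs a vLLM CPU build compiled with AVX512:

```python
# Shape sketch for ops.cpu_gemm_wna16 (illustrative sizes, GPTQ-style call with
# no zeros/g_idx/bias). Requires a vLLM CPU build compiled with AVX512.
import torch
from vllm import _custom_ops as ops

M, K, N, group_size = 4, 256, 512, 128
x = torch.randn(M, K, dtype=torch.bfloat16)                      # [M, K]
q_weight = torch.zeros(N // 16, K * 16 // 8, dtype=torch.int32)  # 4-bit packed, 16-channel blocks
scales = torch.ones(K // group_size, N, dtype=torch.bfloat16)    # [group_num, N]

y = ops.cpu_gemm_wna16(x, q_weight, scales, None, None, None, pack_factor=8, isa_hint="vec")
assert y.shape == (M, N)
```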
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), # If the env var is set, Ray Compiled Graph uses the specified diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 23ace3408..572307052 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -6,7 +6,6 @@ import torch from torch.nn import functional as F from vllm import _custom_ops as ops -from vllm import envs def silu_and_mul(x: torch.Tensor) -> torch.Tensor: @@ -130,54 +129,6 @@ def select_experts( ) -class IPEXFusedMOE: - def __init__(self, layer: torch.nn.Module) -> None: - import intel_extension_for_pytorch as ipex - - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=envs.VLLM_CPU_MOE_PREPACK, - ) - - def __call__( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - ) -> torch.Tensor: - assert activation == "silu", f"{activation} is not supported." - assert not apply_router_weight_on_input - assert routed_scaling_factor == 1.0, ( - f"routed_scaling_factor {routed_scaling_factor} is not supported." - ) - return layer.ipex_fusion( - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - custom_routing_function, - scoring_func, - e_score_correction_bias, - ) - - class SGLFusedMOE: def __init__(self, layer: torch.nn.Module) -> None: pass diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index ce56887f1..2e0376553 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -260,7 +260,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w2_weight.copy_(packed_w2_weight) layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) else: - layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) + layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) else: layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index bb42b10f8..18aaae394 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -38,6 +38,8 @@ QuantizationMethods = Literal[ "inc", "mxfp4", "petit_nvfp4", + "cpu_gptq", + "cpu_awq", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -107,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, ) + from .cpu_wna16 import CPUAWQConfig, CPUGPTQConfig from .deepspeedfp import DeepSpeedFPConfig from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config @@ -159,6 +162,8 @@ def get_quantization_config(quantization: str) -> 
type[QuantizationConfig]: "inc": INCConfig, "mxfp4": Mxfp4Config, "petit_nvfp4": PetitNvFp4Config, + "cpu_gptq": CPUGPTQConfig, + "cpu_awq": CPUAWQConfig, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/cpu_wna16.py b/vllm/model_executor/layers/quantization/cpu_wna16.py new file mode 100644 index 000000000..bf643f55f --- /dev/null +++ b/vllm/model_executor/layers/quantization/cpu_wna16.py @@ -0,0 +1,625 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE + +from vllm._custom_ops import ( + cpu_gemm_wna16, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_linear_quant_method, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, + pack_cols, + unpack_cols, +) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_safetensors_params_metadata +from vllm.utils.collection_utils import is_list_of + +logger = init_logger(__name__) + + +class CPUGPTQConfig(QuantizationConfig): + """Config class for CPU GPTQ quant""" + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + dynamic: dict[str, dict[str, int | bool]], + full_config: dict[str, Any], + modules_in_block_to_quantize: list[str] | None = None, + ) -> None: + super().__init__() + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is dict[str, dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. 
More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + assert weight_bits == 4 + self.dynamic = dynamic + self.weight_bits = weight_bits + self.is_sym = is_sym + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.full_config = full_config + self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] + + def __repr__(self) -> str: + return ( + f"CPUWNA16Config(" + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"dynamic={self.dynamic}, " + f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})" + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "cpu_gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return -1 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "CPUGPTQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + group_size = cls.get_from_keys(config, ["group_size"]) + is_sym = cls.get_from_keys(config, ["sym"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + modules_in_block_to_quantize = cls.get_from_keys_or( + config, ["modules_in_block_to_quantize"], default=None + ) + return cls( + weight_bits, + group_size, + desc_act, + is_sym, + lm_head_quantized, + dynamic, + config, + modules_in_block_to_quantize, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant + ) -> QuantizationMethods | None: + quant_method = hf_quant_cfg.get("quant_method", "").lower() + if current_platform.is_cpu() and (quant_method == "gptq"): + return cls.get_name() + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + return get_linear_quant_method(self, layer, prefix, CPUGPTQLinearMethod) # type: ignore + + def apply_vllm_mapper(self, hf_to_vllm_mapper): + if self.modules_in_block_to_quantize is not None: + self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( + self.modules_in_block_to_quantize + ) + + def maybe_update_config(self, model_name: str, revision: str | None = None): + if self.modules_in_block_to_quantize: + if is_list_of(self.modules_in_block_to_quantize, list): + # original modules_in_block_to_quantize: list[list[str]] + # flatten original modules_in_block_to_quantize + self.modules_in_block_to_quantize = [ + item + for sublist in self.modules_in_block_to_quantize + for item in sublist + ] + return + + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + quant_layers: 
set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes + } + self.modules_in_block_to_quantize = list(quant_layers) + + +class CPUGPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ on CPU. + + Args: + quant_config: The CPUWNA16 quantization config. + """ + + def __init__(self, quant_config: CPUGPTQConfig) -> None: + self.quant_config = quant_config + assert self.quant_config.is_sym, "GPTQ asym quant is not supported on CPU" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + assert output_size_per_partition * self.quant_config.weight_bits % 32 == 0 + assert output_size_per_partition % 32 == 0 + assert input_size_per_partition % 32 == 0 + + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if marlin_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each rank in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. + scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + # Activation order + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) + set_weight_attrs( + g_idx, + {"ignore_warning": True}, + ) + + qzeros_args = { + "data": torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": weight_loader, + } + weight_scale_args = { + "data": torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) + + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False) + 
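The repacking below converts GPTQ's K-packed [K // 8, N] int32 layout into the kernel's layout of [N // 16, K * 16 // 8]: 16 output channels form one block, and each K row of a block occupies two consecutive int32s (16 nibbles). Assuming the usual lowest-nibble-first packing of pack_cols, the destination of a single 4-bit value can be sketched with a hypothetical helper:

```python
# Hypothetical helper, for illustration only: where the 4-bit value of
# (input row k, output channel n) lands in the repacked qweight tensor of
# shape [N // 16, K * 16 // 8], assuming lowest-nibble-first packing.
def repacked_position(k: int, n: int) -> tuple[int, int, int]:
    block = n // 16          # row: block of 16 output channels
    lane = n % 16            # channel index inside the block
    col = 2 * k + lane // 8  # two int32 columns per K row
    nibble = lane % 8        # nibble inside that int32
    return block, col, nibble

print(repacked_position(k=3, n=37))  # -> (2, 6, 5)
```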
packed_weight = layer.qweight.data + bits = self.quant_config.weight_bits + pack_factor = int(self.quant_config.pack_factor) + p_w_k, p_w_n = packed_weight.size() + input_size = p_w_k * pack_factor + output_size = p_w_n + isa_hint = _get_isa_hint(layer.scales.dtype) + layer.isa_hint = isa_hint + + layer.qzeros = None + if not self.quant_config.desc_act: + layer.g_idx = None + + # convert input dim packed to output dim packed + weight = unpack_cols(packed_weight, bits, p_w_k, p_w_n * pack_factor).view( + p_w_k, p_w_n, pack_factor + ) + weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous() + weight = pack_cols(weight, bits, input_size, output_size) + # make 16 output channel as a block and transpose to the make + # the block contigous + weight = ( + weight.view(input_size, -1, 16 // pack_factor) + .permute(1, 0, 2) + .reshape(-1, input_size * 16 // pack_factor) + .contiguous() + ) + layer.qweight.data = weight + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + x = cpu_gemm_wna16( + input=x, + q_weight=layer.qweight, + scales=layer.scales, + zeros=layer.qzeros, + g_idx=layer.g_idx, + bias=bias, + pack_factor=8, + isa_hint=layer.isa_hint, + ) + return x + + +class CPUAWQConfig(QuantizationConfig): + """Config class for CPU AWQ""" + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: list[str] | None, + full_config: dict[str, Any], + ) -> None: + super().__init__() + assert weight_bits == 4 + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.zero_point = zero_point + self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits + self.modules_to_not_convert = modules_to_not_convert or [] + self.full_config = full_config + + def __repr__(self) -> str: + return ( + f"AWQMarlinConfig(" + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) + + @classmethod + def get_name(cls) -> "QuantizationMethods": + return "cpu_awq" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return -1 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "CPUAWQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + weight_bits, + group_size, + zero_point, + lm_head_quantized, + modules_to_not_convert, + config, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant + ) -> Optional["QuantizationMethods"]: + quant_method = hf_quant_cfg.get("quant_method", "").lower() + if current_platform.is_cpu() and (quant_method == "awq"): + return cls.get_name() + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): + if 
is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): + return UnquantizedLinearMethod() + return CPUAWQLinearMethod(self) + return None + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.modules_to_not_convert: + self.modules_to_not_convert = hf_to_vllm_mapper.apply_list( + self.modules_to_not_convert + ) + + def maybe_update_config(self, model_name: str, revision: str | None = None): + if self.modules_to_not_convert: + return + + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + layers = {param_name.rsplit(".", 1)[0] for param_name in metadata} + quant_layers: set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes + } + self.modules_to_not_convert = list(layers - quant_layers) + + +class CPUAWQLinearMethod(LinearMethodBase): + """Linear method for CPU AWQ. + + Args: + quant_config: The CPU AWQ quantization config. + """ + + def __init__(self, quant_config: CPUAWQConfig) -> None: + self.quant_config = quant_config + assert self.quant_config.zero_point + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False) + packed_weight = layer.qweight.data + packed_zeros = layer.qzeros.data + group_num = packed_zeros.size(0) + bits = self.quant_config.weight_bits + pack_factor = int(self.quant_config.pack_factor) + input_size, packed_output_size = packed_weight.size() + output_size = packed_output_size * pack_factor + isa_hint = _get_isa_hint(layer.scales.dtype) + layer.isa_hint = isa_hint + + interleave_map = (0, 4, 1, 5, 2, 6, 3, 7) + weight = unpack_cols( + packed_weight, + bits, + input_size, + output_size, + ) + zeros = unpack_cols( + packed_zeros, + bits, + group_num, + output_size, + ) + weight = ( + weight.view(input_size, -1, 
pack_factor)[:, :, interleave_map]
+            .reshape(input_size, output_size)
+            .contiguous()
+        )
+        zeros = (
+            zeros.view(group_num, -1, pack_factor)[:, :, interleave_map]
+            .reshape(group_num, output_size)
+            .contiguous()
+        )
+
+        zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
+        # group every 16 output channels into a block and transpose to
+        # make each block contiguous
+        weight = pack_cols(weight, bits, input_size, output_size)
+        weight = (
+            weight.view(input_size, -1, 16 // pack_factor)
+            .permute(1, 0, 2)
+            .reshape(-1, input_size * 16 // pack_factor)
+            .contiguous()
+        )
+        layer.qweight.data = weight
+        layer.qzeros.data = zeros
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        x = cpu_gemm_wna16(
+            input=x,
+            q_weight=layer.qweight,
+            scales=layer.scales,
+            zeros=layer.qzeros,
+            g_idx=None,
+            bias=bias,
+            pack_factor=8,
+            isa_hint=layer.isa_hint,
+        )
+        return x
+
+
+def _get_isa_hint(dtype: torch.dtype) -> str:
+    supports_amx = torch._C._cpu._is_amx_tile_supported()
+    if supports_amx and dtype in (torch.bfloat16,):
+        return "amx"
+    else:
+        return "vec"
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 5ca9167fa..22c4bae04 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -134,7 +134,7 @@ class IPEXConfig(QuantizationConfig):
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
     ) -> QuantizationMethods | None:
-        if not current_platform.is_cpu() and not current_platform.is_xpu():
+        if not current_platform.is_xpu():
             return None

         quant_method = hf_quant_cfg.get("quant_method", "").lower()
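
For orientation, a minimal sketch (not part of the patch) of the tensor shapes that CPUGPTQLinearMethod.create_weights allocates for a 4-bit checkpoint. The layer sizes, group size, and single-rank (TP=1) setup below are made-up examples:

weight_bits = 4
pack_factor = 32 // weight_bits            # 8 int4 values per int32 word
input_size_per_partition = 4096            # hypothetical K on this rank
output_size_per_partition = 11008          # hypothetical N on this rank
group_size = 128

# qweight packs 8 input channels into each int32 along dim 0 (packed_dim=0).
qweight_shape = (input_size_per_partition // pack_factor, output_size_per_partition)

# One scales row (and one packed qzeros row) per group of input channels.
num_groups = input_size_per_partition // group_size
scales_shape = (num_groups, output_size_per_partition)
qzeros_shape = (num_groups, output_size_per_partition // pack_factor)

assert qweight_shape == (512, 11008)
assert scales_shape == (32, 11008)
assert qzeros_shape == (32, 1376)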
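
The repacking done in both process_weights_after_loading methods is built on pack_cols/unpack_cols. The sketch below shows the column-wise 4-bit packing this is assumed to perform (each int32 expands into eight adjacent columns, low nibble first), written with plain torch ops; the helper names and shift order are illustrative, not the vLLM implementation itself:

import torch

def unpack_cols_4bit(packed: torch.Tensor) -> torch.Tensor:
    # (rows, cols) int32 -> (rows, cols * 8) nibbles in [0, 15], low nibble first.
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    vals = (packed.unsqueeze(-1) >> shifts) & 0xF
    return vals.reshape(packed.shape[0], -1)

def pack_cols_4bit(vals: torch.Tensor) -> torch.Tensor:
    # Inverse of the above: fold every 8 nibbles back into one int32.
    vals = vals.reshape(vals.shape[0], -1, 8)
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)
    return (vals << shifts).sum(dim=-1, dtype=torch.int32)

x = torch.randint(0, 16, (4, 16), dtype=torch.int32)
assert torch.equal(unpack_cols_4bit(pack_cols_4bit(x)), x)

After the round trip, the GPTQ and AWQ paths additionally regroup the unpacked matrix so that every 16 output channels form one contiguous block, matching the 16-lane vector width the CPU kernel consumes.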
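
The AWQ-specific step is the interleave_map indexing: AWQ checkpoints are assumed to store the eight nibbles of each int32 in the order (0, 2, 4, 6, 1, 3, 5, 7) rather than sequentially, and (0, 4, 1, 5, 2, 6, 3, 7) is the inverse permutation that restores logical channel order after unpacking. A small sketch with made-up values:

interleave_map = (0, 4, 1, 5, 2, 6, 3, 7)

# Suppose unpacking one int32 yielded these nibble values in storage order;
# storage slot i holds the value for logical output channel (0, 2, 4, 6, 1, 3, 5, 7)[i].
unpacked = [10, 30, 50, 70, 20, 40, 60, 80]   # made-up values

# Indexing with interleave_map restores logical channel order 0..7,
# mirroring weight.view(...)[:, :, interleave_map] above.
logical = [unpacked[i] for i in interleave_map]
assert logical == [10, 20, 30, 40, 50, 60, 70, 80]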
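
With override_quantization_method now returning cpu_awq for AWQ checkpoints on CPU (and the IPEX override restricted to XPU), an end-to-end smoke test could look like the following. This is only a usage sketch: the model name is an example and it assumes a CPU build of vLLM; _get_isa_hint will pick the AMX micro-kernel for bfloat16 on AMX-capable machines and the AVX-512 vector path otherwise.

from vllm import LLM

# Hypothetical AWQ checkpoint; any 4-bit AWQ model should exercise CPUAWQLinearMethod.
llm = LLM(model="TheBloke/TinyLlama-1.1B-Chat-v1.0-AWQ", dtype="bfloat16")
print(llm.generate("Hello")[0].outputs[0].text)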