From f09daea261ee98340512e7c7a5fce09db6f8ab72 Mon Sep 17 00:00:00 2001
From: Yintong Lu
Date: Tue, 31 Mar 2026 15:27:37 +0800
Subject: [PATCH] [CPU] Support int8 compute mode in CPU AWQ (#35697)

Signed-off-by: Yintong Lu
---
 .buildkite/hardware_tests/cpu.yaml | 4 +-
 cmake/cpu_extension.cmake | 1 +
 csrc/cpu/sgl-kernels/common.h | 8 +
 csrc/cpu/sgl-kernels/gemm.h | 39 +-
 csrc/cpu/sgl-kernels/gemm_int4.cpp | 755 ++++++++++++++++++
 csrc/cpu/torch_bindings.cpp | 20 +
 tests/kernels/test_awq_int4_to_int8.py | 281 +++++++
 vllm/_custom_ops.py | 32 +
 vllm/envs.py | 3 +
 .../layers/quantization/cpu_wna16.py | 65 +-
 10 files changed, 1197 insertions(+), 11 deletions(-)
 create mode 100644 csrc/cpu/sgl-kernels/gemm_int4.cpp
 create mode 100644 tests/kernels/test_awq_int4_to_int8.py

diff --git a/.buildkite/hardware_tests/cpu.yaml b/.buildkite/hardware_tests/cpu.yaml
index acca2b368..e466e2a52 100644
--- a/.buildkite/hardware_tests/cpu.yaml
+++ b/.buildkite/hardware_tests/cpu.yaml
@@ -13,12 +13,14 @@ steps:
     - tests/kernels/attention/test_cpu_attn.py
     - tests/kernels/moe/test_cpu_fused_moe.py
     - tests/kernels/test_onednn.py
+    - tests/kernels/test_awq_int4_to_int8.py
   commands:
     - |
       bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
         pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
         pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
-        pytest -x -v -s tests/kernels/test_onednn.py"
+        pytest -x -v -s tests/kernels/test_onednn.py
+        pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
 
 - label: CPU-Compatibility Tests
   depends_on: []
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 8d74d6d5d..1b3f0d5ad 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
         "csrc/cpu/sgl-kernels/gemm.cpp"
         "csrc/cpu/sgl-kernels/gemm_int8.cpp"
         "csrc/cpu/sgl-kernels/gemm_fp8.cpp"
+        "csrc/cpu/sgl-kernels/gemm_int4.cpp"
         "csrc/cpu/sgl-kernels/moe.cpp"
         "csrc/cpu/sgl-kernels/moe_int8.cpp"
         "csrc/cpu/sgl-kernels/moe_fp8.cpp")
diff --git a/csrc/cpu/sgl-kernels/common.h b/csrc/cpu/sgl-kernels/common.h
index b96037e82..31be725fa 100644
--- a/csrc/cpu/sgl-kernels/common.h
+++ b/csrc/cpu/sgl-kernels/common.h
@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
 #endif
 }
 
+inline int get_thread_num() {
+#if defined(_OPENMP)
+  return omp_get_thread_num();
+#else
+  return 0;
+#endif
+}
+
 // for 1d parallel, use `actual_nth`
 // for 2d parallel, use even nths, e.g. 43->42
 int inline adjust_num_threads(int m) {
diff --git a/csrc/cpu/sgl-kernels/gemm.h b/csrc/cpu/sgl-kernels/gemm.h
index fba567332..aa78c8807 100644
--- a/csrc/cpu/sgl-kernels/gemm.h
+++ b/csrc/cpu/sgl-kernels/gemm.h
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
 template <typename T> inline bool can_use_brgemm(int M);
 template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
 template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
-// TODO: add u8s8 brgemm, this requires PyTorch 2.7
-template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
+template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
+template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
 template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
 template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
 
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
   return use_int8_w8a8 ? K + sizeof(int32_t) : K;
 }
 
-// pack weight to vnni format
+inline int64_t get_4bit_block_k_size(int64_t group_size) {
+  return group_size > 128 ? 128 : group_size;
+}
+
+// pack weight into vnni format
 at::Tensor convert_weight_packed(at::Tensor& weight);
 
+// pack weight to vnni format for int4 (adapted from sglang)
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
+
 // moe implementations for int8 w8a8
 template <typename scalar_t>
 void fused_experts_int8_kernel_impl(
@@ -233,6 +241,31 @@ void tinygemm_kernel(
     int64_t strideBs, bool brg);
 
+// int4 scaled GEMM (adapted from sglang)
+at::Tensor int4_scaled_mm_cpu(
+    at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
+
+// int4 tinygemm kernel interface (adapted from sglang)
+template <typename scalar_t>
+void tinygemm_kernel(
+    scalar_t* C,
+    float* C_temp,
+    const uint8_t* A,
+    const float* scales_a,
+    const int32_t* qzeros_a,
+    const uint8_t* B,
+    const float* scales_b,
+    const int8_t* qzeros_b,
+    const int32_t* compensation,
+    int8_t* dqB_tmp,
+    int64_t M,
+    int64_t K,
+    int64_t lda,
+    int64_t ldc_f,
+    int64_t ldc_s,
+    bool store_out,
+    bool use_brgemm);
+
 // TODO: debug print, remove me later
 inline void print_16x32i(const __m512i x) {
   int32_t a[16];
diff --git a/csrc/cpu/sgl-kernels/gemm_int4.cpp b/csrc/cpu/sgl-kernels/gemm_int4.cpp
new file mode 100644
index 000000000..4a04c5066
--- /dev/null
+++ b/csrc/cpu/sgl-kernels/gemm_int4.cpp
@@ -0,0 +1,755 @@
+// SPDX-License-Identifier: Apache-2.0
+// Adapted from sgl-project/sglang
+// https://github.com/sgl-project/sglang/pull/8226
+
+#include <ATen/native/CPUBlas.h>
+
+#include "common.h"
+#include "gemm.h"
+#include "vec.h"
+
+namespace {
+
+#define BLOCK_N block_size_n()
+#define BLOCK_M 128
+
+template <bool sym_quant_act>
+struct ActDtype;
+template <>
+struct ActDtype<true> {
+  using type = int8_t;
+};
+template <>
+struct ActDtype<false> {
+  using type = uint8_t;
+};
+
+struct alignas(32) m256i_wrapper {
+  __m256i data;
+};
+
+#if defined(CPU_CAPABILITY_AVX512)
+inline std::array<m256i_wrapper, 2> load_zps_4vnni(
+    const int8_t* __restrict__ zps) {
+  __m256i vzps_low = _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps));
+  __m256i vzps_high =
+      _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(zps + 8));
+  __m256i shuffle_mask =
+      _mm256_set_epi8(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3,
+                      3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
+  vzps_low = _mm256_shuffle_epi8(vzps_low, shuffle_mask);
+  vzps_high = _mm256_shuffle_epi8(vzps_high, shuffle_mask);
+  m256i_wrapper vzps_low_wp, vzps_high_wp;
+  vzps_low_wp.data = vzps_low;
+  vzps_high_wp.data = vzps_high;
+  return {vzps_low_wp, vzps_high_wp};
+}
+
+inline std::array<m256i_wrapper, 2> load_uint4_as_int8(
+    const uint8_t* __restrict__ qB) {
+  __m256i packed = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(qB));
+  const __m256i low_mask = _mm256_set1_epi8(0x0f);
+  __m256i high = _mm256_srli_epi16(packed, 4);
+  high = _mm256_and_si256(high, low_mask);
+  __m256i low = _mm256_and_si256(packed, low_mask);
+  m256i_wrapper low_wp, high_wp;
+  low_wp.data = low;
+  high_wp.data = high;
+  return {low_wp, high_wp};
+}
+
+template <int64_t N, int64_t ldb>
+void _dequant_weight_zp_only(const uint8_t* __restrict__ B, int8_t* dqB,
+                             const int8_t* __restrict__ qzeros, int64_t K) {
+  #pragma GCC unroll 2
+  for (int n = 0; n < N; n += 16) {
+    auto [zps_low_wp, zps_high_wp] = load_zps_4vnni(&qzeros[n]);
+    auto zps_low = zps_low_wp.data;
+    auto zps_high = zps_high_wp.data;
+    for (int k = 0; k < K; k += 4) {
+      auto [vb_low_wp, vb_high_wp] =
+          load_uint4_as_int8(B + ldb * k + n / 2 * 4);
+      auto vb_low = vb_low_wp.data;
+      auto vb_high = vb_high_wp.data;
+      vb_high = _mm256_sub_epi8(vb_high, zps_high);
+      vb_low = _mm256_sub_epi8(vb_low, zps_low);
+      _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(dqB + N * k + n * 4),
+                          vb_low);
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i_u*>(dqB + N * k + (n + 8) * 4), vb_high);
+    }
+  }
+}
+
+template <bool accum, int64_t N, bool sym_quant_act>
+void _dequant_and_store(float* __restrict__ output,
+                        const int32_t* __restrict__ input,
+                        const float* __restrict__ scale_a,
+                        const int32_t* __restrict__ zp_a,
+                        const float* __restrict__ scale_b,
+                        const int32_t* __restrict__ comp_b, int M, int ldi,
+                        int ldo, int ldsa = 1) {
+  for (int m = 0; m < M; ++m) {
+    float a_scale = *(scale_a + m * ldsa);
+    __m512 va_scale = _mm512_set1_ps(a_scale);
+    int32_t a_zp;
+    __m512i va_zp;
+    if constexpr (!sym_quant_act) {
+      a_zp = *(zp_a + m * ldsa);
+      va_zp = _mm512_set1_epi32(a_zp);
+    }
+    int n = 0;
+  #pragma GCC unroll 2
+    for (; n < N; n += 16) {
+      __m512i vc = _mm512_loadu_si512(input + m * ldi + n);
+      if constexpr (!sym_quant_act) {
+        __m512i vb_comp = _mm512_loadu_si512(comp_b + n);
+        vc = _mm512_sub_epi32(vc, _mm512_mullo_epi32(vb_comp, va_zp));
+      }
+      __m512 vc_f = _mm512_cvtepi32_ps(vc);
+      __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale);
+      __m512 vb_s = _mm512_loadu_ps(scale_b + n);
+      vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s);
+      if constexpr (accum) {
+        __m512 vo = _mm512_loadu_ps(output + m * ldo + n);
+        _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul));
+      } else {
+        _mm512_storeu_ps(output + m * ldo + n, vc_f_mul);
+      }
+    }
+    for (; n < N; ++n) {
+      float dq_val;
+      if constexpr (sym_quant_act) {
+        dq_val = (float)input[m * ldi + n] * a_scale * scale_b[n];
+      } else {
+        dq_val = (float)(input[m * ldi + n] - a_zp * comp_b[n]) * a_scale *
+                 scale_b[n];
+      }
+      if constexpr (accum) {
+        output[m * ldo + n] += dq_val;
+      } else {
+        output[m * ldo + n] = dq_val;
+      }
+    }
+  }
+}
+
+#else
+template <int64_t N, int64_t ldb>
+void _dequant_weight_zp_only(const uint8_t* B, int8_t* dqB,
+                             const int8_t* qzeros, int64_t K) {
+  for (int k = 0; k < K; ++k) {
+    for (int n = 0; n < N / 2; ++n) {
+      int32_t b = (int32_t)B[k * ldb + n];
+      dqB[k * N + n * 2] = (b & 0xf) - qzeros[n];
+      dqB[k * N + n * 2 + 1] = (b >> 4) - qzeros[n];
+    }
+  }
+}
+#endif
+
+#if defined(CPU_CAPABILITY_AVX512)
+inline __m512i combine_m256i(__m256i a, __m256i b) {
+  __m512i c = _mm512_castsi256_si512(a);
+  return _mm512_inserti64x4(c, b, 1);
+}
+
+inline __m512i combine_m256i(std::array<m256i_wrapper, 2> two_256) {
+  return combine_m256i(two_256[0].data, two_256[1].data);
+}
+
+static inline __m512i _mm512_sign_epi8(__m512i a, __m512i b) {
+  __m512i zero = _mm512_setzero_si512();
+  __mmask64 blt0 = _mm512_movepi8_mask(b);
+  return _mm512_mask_sub_epi8(a, blt0, zero, a);
+}
+
+template <int64_t M, int64_t N, int64_t ldb, bool sym_quant_act>
+void _dequant_gemm_accum_small_M(float* __restrict__ C, const uint8_t* A,
+                                 const float* scales_a, const int32_t* qzeros_a,
+                                 const uint8_t* B, const float* scales_b,
+                                 const int8_t* qzeros_b, int64_t K, int64_t lda,
+                                 int64_t ldc) {
+  constexpr int COLS = N / 16;
+  __m512i ones = _mm512_set1_epi8(1);
+  __m512i va;
+  __m512i vb[COLS];
+  __m512i vc[M * COLS];
+  __m512 vscales[COLS];
+  __m512i vzps[COLS];
+  __m512i vcompensate[COLS];
+
+  Unroll<COLS>{}([&](auto i) {
+    vscales[i] = _mm512_loadu_ps(scales_b + i * 16);
+    vzps[i] = combine_m256i(load_zps_4vnni(qzeros_b + i * 16));
+    if constexpr (!sym_quant_act) {
+      vcompensate[i] = _mm512_setzero_epi32();
+    }
+  });
+  Unroll<M * COLS>{}([&](auto i) { vc[i] = _mm512_setzero_epi32(); });
+
+  auto compute = [&](auto i, int k) {
+    constexpr const int row = i / COLS;
+    constexpr const int col = i % COLS;
+
+    if constexpr (col == 0) {
+      va = _mm512_set1_epi32(*(int32_t*)(A + row * lda + k));
+    }
+
+    if constexpr (row == 0) {
+      int B_offset = k * ldb + col * 16 * 2;
+      vb[col] = combine_m256i(load_uint4_as_int8(B + B_offset));
+      vb[col] = _mm512_sub_epi8(vb[col], vzps[col]);
+      if constexpr (!sym_quant_act) {
+        vcompensate[col] = _mm512_dpbusd_epi32(vcompensate[col], ones, vb[col]);
+      }
+      _mm_prefetch(B + B_offset + 128 * ldb, _MM_HINT_T0);
+    }
+    if constexpr (sym_quant_act) {
+      auto vsb = _mm512_sign_epi8(vb[col], va);
+      auto vabsa = _mm512_sign_epi8(va, va);
+      vc[i] = _mm512_dpbusds_epi32(vc[i], vabsa, vsb);
+    } else {
+      vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
+    }
+  };
+
+  constexpr const int unroll = 4;
+  int k = 0;
+  for (; k < K / 4 / unroll; k++) {
+    Unroll<unroll>{}(
+        [&](auto i) { Unroll<M * COLS>{}(compute, 4 * (k * unroll + i)); });
+  }
+  k *= 4 * unroll;
+  for (; k < K; k += 4) {
+    Unroll<M * COLS>{}(compute, k);
+  }
+
+  auto store = [&](auto i) {
+    constexpr const int row = i / COLS;
+    constexpr const int col = i % COLS;
+    __m512 vc_float;
+    if constexpr (!sym_quant_act) {
+      vc[i] = _mm512_sub_epi32(
+          vc[i], _mm512_mullo_epi32(vcompensate[col],
+                                    _mm512_set1_epi32(*(qzeros_a + row))));
+    }
+    vc_float = _mm512_cvtepi32_ps(vc[i]);
+    vc_float = _mm512_mul_ps(vc_float, _mm512_set1_ps(*(scales_a + row)));
+
+    vc_float = _mm512_mul_ps(vc_float, vscales[col]);
+    auto vc_old = _mm512_loadu_ps(C + row * ldc + col * 16);
+    vc_float = _mm512_add_ps(vc_float, vc_old);
+    _mm512_storeu_ps(C + row * ldc + col * 16, vc_float);
+  };
+  Unroll<M * COLS>{}(store);
+}
+
+  #define CALL_DEQUANT_GEMM_ACCUM_SMALL_M(M)               \
+    _dequant_gemm_accum_small_M<M, N, ldb, sym_quant_act>( \
+        C, A, scales_a, qzeros_a, B, scales_b, qzeros_b, K, lda, ldc);
+#endif
+
+template <int64_t N, int64_t ldb, bool sym_quant_act>
+void _dequant_gemm_accum(float* C, const uint8_t* A, const float* scales_a,
+                         const int32_t* qzeros_a, const uint8_t* B,
+                         const float* scales_b, const int8_t* qzeros_b,
+                         const int32_t* compensation, int8_t* dqB, int64_t M,
+                         int64_t K, int64_t lda, int64_t ldc, bool use_brgemm) {
+#if defined(CPU_CAPABILITY_AVX512)
+  if (!use_brgemm) {
+    switch (M) {
+      case 1:
+        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(1);
+        break;
+      case 2:
+        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(2);
+        break;
+      case 3:
+        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(3);
+        break;
+      case 4:
+        CALL_DEQUANT_GEMM_ACCUM_SMALL_M(4);
+        break;
+      default:
+        TORCH_CHECK(false, "tinygemm_kernel: unexpected M for AVX path!");
+    }
+    return;
+  }
+
+  _dequant_weight_zp_only<N, ldb>(B, dqB, qzeros_b, K);
+  using Tin = typename ActDtype<sym_quant_act>::type;
+  Tin* A_ptr = (Tin*)A;
+  if (use_brgemm) {
+    int32_t C_i32[M * N];
+    at::native::cpublas::brgemm(M, N, K, lda, N /*ldb*/, N /*ldc*/,
+                                false /* add_C */, A_ptr, dqB, C_i32,
+                                true /* is_vnni */);
+    _mm_prefetch(B + N * K / 2, _MM_HINT_T0);
+    _mm_prefetch(A + K, _MM_HINT_T0);
+    _dequant_and_store<true, N, sym_quant_act>(C, C_i32, scales_a, qzeros_a,
+                                               scales_b, compensation, M,
+                                               N /*ldi*/, ldc, 1 /*ldsa*/);
+  } else
+#endif
+  {
+    TORCH_CHECK(false, "tinygemm_kernel: scalar path not implemented!");
+  }
+}
+
+template <int64_t N>
+inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) {
+  if (bias_ptr) {
+    for (int i = 0; i < m; ++i) {
+      int j = 0;
+#if defined(CPU_CAPABILITY_AVX512)
+  #pragma GCC unroll 2
+      for (; j < N; j += 16) {
+        __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j);
+        _mm512_storeu_ps(y_buf + i * N + j, bias_vec);
+      }
+#endif
+      for (; j < N; ++j) {
+        y_buf[i * N + j] = bias_ptr[j];
+      }
+    }
+  } else {
+    for (int i = 0; i < m; ++i) {
+      int j = 0;
+#if defined(CPU_CAPABILITY_AVX512)
+  #pragma GCC unroll 2
+      for (; j < N; j += 16) {
+        __m512 zero_vec = _mm512_setzero_ps();
+        _mm512_storeu_ps(y_buf + i * N + j, zero_vec);
+      }
+#endif
+      for (; j < N; ++j) {
+        y_buf[i * N + j] = 0;
+      }
+    }
+  }
+}
+
+template <typename out_dtype, int64_t N>
+inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m,
+                      int64_t lda) {
+  for (int i = 0; i < m; ++i) {
+    int j = 0;
+    if constexpr (std::is_same<out_dtype, float>::value) {
+#if defined(CPU_CAPABILITY_AVX512)
+  #pragma GCC unroll 2
+      for (; j < N; j += 16) {
+        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
+        _mm512_storeu_ps(c_ptr + i * lda + j, y_vec);
+      }
+#endif
+      for (; j < N; ++j) {
+        c_ptr[i * lda + j] = y_buf[i * N + j];
+      }
+    } else if constexpr (std::is_same<out_dtype, at::BFloat16>::value) {
+#if defined(CPU_CAPABILITY_AVX512)
+  #pragma GCC unroll 2
+      for (; j < N; j += 16) {
+        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
+        __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec);
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
+                            y_bf16_vec);
+      }
+#endif
+      for (; j < N; ++j) {
+        c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]);
+      }
+    } else if constexpr (std::is_same<out_dtype, at::Half>::value) {
+#if defined(CPU_CAPABILITY_AVX512)
+  #pragma GCC unroll 2
+      for (; j < N; j += 16) {
+        __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j);
+        __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec);
+        _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j),
+                            y_fp16_vec);
+      }
+#endif
+      for (; j < N; ++j) {
+        c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]);
+      }
+    } else {
+      TORCH_CHECK(false, "Unsupported output dtype");
+    }
+  }
+}
+
+void fill_val_stub(int32_t* __restrict__ output, int32_t value, int64_t size) {
+  using iVec = at::vec::Vectorized<int32_t>;
+  constexpr int VecSize = iVec::size();
+  const iVec fill_val_vec = iVec(value);
+  int64_t d;
+#pragma GCC unroll 4
+  for (d = 0; d <= size - VecSize; d += VecSize) {
+    fill_val_vec.store(output + d);
+  }
+  for (; d < size; ++d) {
+    output[d] = value;
+  }
+}
+
+template <typename act_dtype, typename out_dtype, bool sym_quant_act>
+void _da8w4_linear_impl(
+    act_dtype* __restrict__ input, const float* __restrict__ input_scales,
+    const int32_t* __restrict__ input_qzeros,
+    const uint8_t* __restrict__ weight, const float* __restrict__ weight_scales,
+    const int8_t* __restrict__ weight_qzeros, const float* __restrict__ bias,
+    out_dtype* __restrict__ output, float* __restrict__ output_temp,
+    int8_t* __restrict__ dequant_weight_temp, int64_t M, int64_t N, int64_t K,
+    int64_t num_groups) {
+  const bool use_brgemm = can_use_brgemm<act_dtype>(M);
+  int64_t block_m = [&]() -> long {
+    if (M <= 48) {
+      return M;
+    } else if (M < 64) {
+      return 32;
+    } else if (M < 96) {
+      return 64;
+    } else {
+      return 128;
+    }
+  }();
+  int64_t Mc = div_up(M, block_m);
+  bool parallel_on_M = M > 128;
+  int64_t Nc = N / BLOCK_N;
+  int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc;
+  int64_t group_size = div_up(K, num_groups);
+  int64_t _block_k = get_4bit_block_k_size(group_size);
+  int64_t Kc = K / _block_k;
+  int64_t block_per_group = group_size / _block_k;
+
+  at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
+    int tid = get_thread_num();
+    float* C_tmp = output_temp + tid * block_m * BLOCK_N;
+    int8_t* dqB_tmp = dequant_weight_temp + tid * _block_k * BLOCK_N;
+    for (const auto i : c10::irange(begin, end)) {
+      int64_t mc = parallel_on_M ? i / Nc : 0;
+      int64_t nc = parallel_on_M ? i % Nc : i;
+      int64_t mc_end = parallel_on_M ? mc + 1 : Mc;
+
+      for (int mci = mc; mci < mc_end; ++mci) {
+        int64_t m_size =
+            mci * block_m + block_m > M ? M - mci * block_m : block_m;
+        auto bias_data = bias ? bias + nc * BLOCK_N : nullptr;
+        copy_bias<BLOCK_N>(bias_data, C_tmp, m_size);
+        for (int kci = 0; kci < Kc; ++kci) {
+          int32_t* compensation_ptr =
+              sym_quant_act
+                  ? nullptr
+                  : (int32_t*)(void*)(weight +
+                                      (nc * Kc + kci) *
+                                          (BLOCK_N *
+                                           (_block_k / 2 + sizeof(int32_t))) +
+                                      _block_k * BLOCK_N / 2);
+          _dequant_gemm_accum<BLOCK_N, BLOCK_N / 2, sym_quant_act>(
+              /*C*/ C_tmp,
+              /*A*/ (uint8_t*)input + mci * block_m * K + kci * _block_k,
+              /*scales_a*/ input_scales + mci * block_m,
+              /*qzeros_a*/ input_qzeros + mci * block_m,
+              /*B*/ weight + (nc * Kc + kci) *
+                                 (BLOCK_N * (_block_k / 2 + sizeof(int32_t))),
+              /*scales_b*/ weight_scales + nc * BLOCK_N * num_groups +
+                  kci / block_per_group * BLOCK_N,
+              /*qzeros_b*/ weight_qzeros + nc * BLOCK_N * num_groups +
+                  kci / block_per_group * BLOCK_N,
+              /*Bcomp*/ compensation_ptr,
+              /*dqB_tmp*/ dqB_tmp,
+              /*M*/ m_size,
+              /*K*/ _block_k,
+              /*lda*/ K,
+              /*ldc*/ BLOCK_N,
+              /*use_brgemm*/ use_brgemm);
+        }
+        store_out<out_dtype, BLOCK_N>(
+            C_tmp, output + mci * block_m * N + nc * BLOCK_N, m_size,
+            N /*lda*/);
+      }
+    }
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
+  });
+}
+
+}  // anonymous namespace
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor>
+convert_int4_weight_packed_with_compensation(const at::Tensor& weight,
+                                             const at::Tensor& scales,
+                                             const at::Tensor& qzeros) {
+  TORCH_CHECK(weight.dim() == 2,
+              "DA8W4 CPU: Weight should be a 2D tensor for packing");
+  TORCH_CHECK(
+      weight.size(1) % 2 == 0,
+      "DA8W4 CPU: Weight should have even number of columns for packing");
+
+  auto new_scales = scales;
+  auto new_qzeros = qzeros;
+  if (new_scales.dim() == 1) {
+    new_scales.unsqueeze_(1);
+  }
+  new_scales = new_scales.to(at::kFloat);
+  if (new_qzeros.dim() == 1) {
+    new_qzeros.unsqueeze_(1);
+  }
+  new_qzeros = new_qzeros.to(at::kChar);
+  int64_t N = weight.size(0);
+  int64_t K = weight.size(1);
+  int64_t G = scales.size(1);
+  int64_t group_size = K / G;
+  int64_t _block_k = get_4bit_block_k_size(group_size);
+  constexpr int block_n = block_size_n();
+  int64_t Nc = N / block_n;
+  int64_t Kc = K / _block_k;
+
+  auto weight_view = weight.view({Nc, block_n, Kc, _block_k});
+  at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous();
+  at::Tensor blocked_weight;
+  at::Tensor blocked_scales =
+      new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
+  at::Tensor blocked_qzeros =
+      new_qzeros.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous();
+  auto weight_sub_qzero = weight.view({Nc, block_n, G, -1}).to(at::kInt) -
+                          new_qzeros.view({Nc, block_n, G, -1});
+  weight_sub_qzero = weight_sub_qzero.view({Nc, block_n, Kc, _block_k});
+  at::Tensor compensation = weight_sub_qzero.sum(-1);
+  compensation = compensation.permute({0, 2, 1}).contiguous().to(at::kInt);
+  int64_t buffer_size_nbytes =
+      _block_k * block_n / 2 + block_n * sizeof(int32_t);
+  blocked_weight = at::empty({Nc, Kc, buffer_size_nbytes}, weight.options());
+
+  auto weight_ptr = weight_reordered.data_ptr<uint8_t>();
+  auto compensation_ptr = compensation.data_ptr<int32_t>();
+  auto blocked_weight_ptr = blocked_weight.data_ptr<uint8_t>();
+  int64_t num_blocks = Nc * Kc;
+  at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) {
+    for (const auto i : c10::irange(begin, end)) {
+      auto in_ptr = weight_ptr + i * _block_k * block_n;
+      auto out_ptr =
+          blocked_weight_ptr + i * block_n * (_block_k / 2 + sizeof(int32_t));
+      int32_t* comp_in_ptr = compensation_ptr + i * block_n;
+      int32_t* comp_out_ptr =
+          (int32_t*)(void*)(blocked_weight_ptr +
+                            i * block_n * (_block_k / 2 + sizeof(int32_t)) +
+                            _block_k * block_n / 2);
+      constexpr int n_group_size = 8;
+      constexpr int vnni_size = 4;
+      constexpr int n_group = block_n / n_group_size;
+      for (int nb = 0; nb < n_group; nb += 2) {
+        for (int k = 0; k < _block_k; k += vnni_size) {
+          for (int ni = 0; ni < n_group_size; ++ni) {
+            for (int ki = 0; ki < vnni_size; ++ki) {
+              int src_idx_1 = nb * n_group_size + ni + (k + ki) * block_n;
+              int src_idx_2 = (nb + 1) * n_group_size + ni + (k + ki) * block_n;
+              int dst_idx = (nb / 2 * n_group_size + ni) * vnni_size +
+                            k * block_n / 2 + ki;
+              uint8_t src_1 = *(in_ptr + src_idx_1);
+              uint8_t src_2 = *(in_ptr + src_idx_2);
+              uint8_t dst = (src_1 & 0x0f) | ((src_2 & 0x0f) << 4);
+              *(out_ptr + dst_idx) = dst;
+            }
+          }
+        }
+      }
+      for (int nb = 0; nb < block_n; nb++) {
+        *(comp_out_ptr + nb) = *(comp_in_ptr + nb);
+      }
+    }
+  });
+
+  return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales),
+                         std::move(blocked_qzeros));
+}
+
+std::tuple<at::Tensor, at::Tensor> autoawq_to_int4pack(at::Tensor qweight,
+                                                       at::Tensor qzeros) {
+  auto bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}, at::kInt) * 4;
+  auto qweight_unsq = qweight.unsqueeze(-1);
+  auto unpacked = at::bitwise_right_shift(qweight_unsq, bitshifts) & 0xF;
+  auto qweight_final = unpacked.flatten(-2).transpose(-1, -2).to(at::kByte);
+
+  auto qzeros_unsq = qzeros.unsqueeze(-1);
+  auto qzeros_unpacked = at::bitwise_right_shift(qzeros_unsq, bitshifts) & 0xF;
+  auto qzeros_final = qzeros_unpacked.flatten(-2).to(at::kByte);
+
+  return std::make_tuple(qweight_final, qzeros_final);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
+    at::Tensor qweight, at::Tensor qzeros, at::Tensor scales) {
+  auto res = autoawq_to_int4pack(qweight, qzeros);
+  auto _qweight = std::get<0>(res);
+  auto _qzeros = std::get<1>(res);
+  auto _scales = scales;
+  _qzeros = _qzeros.transpose(-2, -1).contiguous();
+  _scales = _scales.transpose(-2, -1).contiguous();
+  if (_qweight.dim() == 3) {
+    int64_t E = _qweight.size(0);
+    int64_t K = _qweight.size(2);
+    int64_t G = _scales.size(2);
+    int64_t group_size = K / G;
+    int64_t _block_k = get_4bit_block_k_size(group_size);
+    int64_t block_n = block_size_n();
+    int64_t Nc = _qweight.size(1) / block_n;
+    int64_t Kc = K / _block_k;
+    int64_t buffer_size_nbytes =
+        _block_k * block_n / 2 + block_n * sizeof(int32_t);
+    auto blocked_weight =
+        at::empty({E, Nc, Kc, buffer_size_nbytes}, _qweight.options());
+    auto blocked_scales =
+        at::empty({E, Nc, G, block_n}, _scales.options()).to(at::kFloat);
+    auto blocked_qzeros =
+        at::empty({E, Nc, G, block_n}, _qzeros.options()).to(at::kChar);
+    for (int i = 0; i < _qweight.size(0); i++) {
+      auto res_ = convert_int4_weight_packed_with_compensation(
+          _qweight[i], _scales[i], _qzeros[i]);
+      blocked_weight[i] = std::get<0>(res_);
+      blocked_scales[i] = std::get<1>(res_);
+      blocked_qzeros[i] = std::get<2>(res_);
+    }
+    _qweight = blocked_weight;
+    _scales = blocked_scales;
+    _qzeros = blocked_qzeros;
+  } else {
+    auto res_ = convert_int4_weight_packed_with_compensation(_qweight, _scales,
+                                                             _qzeros);
+    _qweight = std::get<0>(res_);
+    _scales = std::get<1>(res_);
+    _qzeros = std::get<2>(res_);
+  }
+
+  return std::make_tuple(_qweight, _qzeros, _scales);
+}
+
+at::Tensor int4_scaled_mm_cpu_with_quant(const at::Tensor& input,
+                                         const at::Tensor& weight,
+                                         const at::Tensor& weight_scales,
+                                         const at::Tensor& weight_qzeros,
+                                         const std::optional<at::Tensor>& bias,
+                                         at::ScalarType output_dtype) {
+  RECORD_FUNCTION("vllm::int4_scaled_mm_cpu_with_quant",
+                  std::vector<c10::IValue>({input, weight}));
+
+  int64_t M_a = input.size(0);
+  int64_t K_a = input.size(1);
+  int64_t lda = input.stride(0);
+
+  const auto st = input.scalar_type();
+  TORCH_CHECK(
+      st == at::kBFloat16 || st == at::kHalf,
+      "int4_scaled_mm_cpu_with_quant: expect A to be bfloat16 or half.");
+
+  constexpr bool sym_quant_act = false;
+  using Tin = typename ActDtype<sym_quant_act>::type;
+  int64_t act_buffer_size =
+      M_a * K_a + M_a * sizeof(float) + M_a * sizeof(int32_t);
+  auto act_buffer =
+      at::empty({act_buffer_size}, input.options().dtype(at::kByte));
+  auto Aq_data = act_buffer.data_ptr<uint8_t>();
+  auto As_data = reinterpret_cast<float*>(Aq_data + M_a * K_a);
+  auto Azp_data = reinterpret_cast<int32_t*>(As_data + M_a);
+  fill_val_stub(Azp_data, 128, M_a);
+
+  auto out_sizes = input.sizes().vec();
+  int64_t N = weight_scales.size(0) * weight_scales.size(-1);
+  out_sizes.back() = N;
+  auto output = at::empty(out_sizes, input.options());
+  int64_t Nc = weight.size(0);
+  int64_t Kc = weight.size(1);
+  int64_t _block_k = K_a / Kc;
+  TORCH_CHECK(N == Nc * BLOCK_N, "DA8W4: weight and input shapes mismatch");
+  int64_t num_groups = weight_scales.size(1);
+
+  const uint8_t* b_ptr = weight.data_ptr<uint8_t>();
+  const float* b_scales_ptr = weight_scales.data_ptr<float>();
+  const int8_t* b_qzeros_ptr = weight_qzeros.data_ptr<int8_t>();
+  const float* bias_ptr =
+      bias.has_value() ? bias.value().data_ptr<float>() : nullptr;
+  int num_threads = at::get_num_threads();
+  int64_t temp_buffer_size = num_threads * BLOCK_M * BLOCK_N * sizeof(float) +
+                             num_threads * _block_k * BLOCK_N;
+  auto c_temp_buffer =
+      at::empty({temp_buffer_size}, input.options().dtype(at::kChar));
+  float* c_temp_ptr = (float*)((void*)(c_temp_buffer.data_ptr()));
+  int8_t* dqB_temp_ptr =
+      (int8_t*)((void*)(c_temp_ptr + num_threads * BLOCK_M * BLOCK_N));
+
+#define LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act)                  \
+  AT_DISPATCH_FLOATING_TYPES_AND2(                                          \
+      at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype,         \
+      "int4_scaled_mm_cpu", [&] {                                           \
+        const scalar_t* __restrict__ A_data = input.data_ptr<scalar_t>();   \
+        scalar_t* __restrict__ c_ptr = output.data_ptr<scalar_t>();         \
+        at::parallel_for(0, M_a, 0, [&](int64_t begin, int64_t end) {       \
+          for (int64_t m = begin; m < end; ++m) {                           \
+            quantize_row_int8(Aq_data + m * K_a, As_data[m],                \
+                              A_data + m * lda, K_a);                       \
+          }                                                                 \
+        });                                                                 \
+        _da8w4_linear_impl<Tin, scalar_t, sym_quant_act>(                   \
+            Aq_data, As_data, Azp_data, b_ptr, b_scales_ptr, b_qzeros_ptr,  \
+            bias_ptr, c_ptr, c_temp_ptr, dqB_temp_ptr, M_a, N, K_a,         \
+            num_groups);                                                    \
+      });
+
+  LAUNCH_DA8W4_LINEAR_WITH_QUANT_IMPL(sym_quant_act);
+
+  return output;
+}
+
+namespace {
+
+template <typename scalar_t>
+inline void copy_stub(scalar_t* __restrict__ out,
+                      const float* __restrict__ input, int64_t size) {
+  using Vec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+#pragma GCC unroll 4
+  for (int64_t d = 0; d < size; d += Vec::size()) {
+    fVec x0 = fVec::loadu(input + d);
+    fVec x1 = fVec::loadu(input + d + fVec::size());
+    Vec res = convert_from_float_ext(x0, x1);
+    res.store(out + d);
+  }
+}
+
+}  // anonymous namespace
+
+template <typename scalar_t>
+void tinygemm_kernel(scalar_t* C, float* C_temp, const uint8_t* A,
+                     const float* scales_a, const int32_t* qzeros_a,
+                     const uint8_t* B, const float* scales_b,
+                     const int8_t* qzeros_b, const int32_t* compensation,
+                     int8_t* dqB_tmp, int64_t M, int64_t K, int64_t lda,
+                     int64_t ldc_f, int64_t ldc_s, bool store_out,
+                     bool use_brgemm) {
+  _dequant_gemm_accum<BLOCK_N, BLOCK_N / 2, false>(
+      C_temp, A, scales_a, qzeros_a, B, scales_b, qzeros_b, compensation,
+      dqB_tmp, M, K, lda, ldc_f, use_brgemm);
+  if (store_out) {
+    for (int64_t m = 0; m < M; ++m) {
+      copy_stub(C + m * ldc_s, C_temp + m * ldc_f, BLOCK_N);
+    }
+  }
+}
+
+#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                 \
+  template void tinygemm_kernel<TYPE>(                                      \
+      TYPE * C, float* C_temp, const uint8_t* A, const float* scales_a,     \
+      const int32_t* qzeros_a, const uint8_t* B, const float* scales_b,     \
+      const int8_t* qzeros_b, const int32_t* compensation, int8_t* dqB_tmp, \
+      int64_t M, int64_t K, int64_t lda, int64_t ldc_f, int64_t ldc_s,      \
+      bool store_out, bool use_brgemm)
+
+INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
+INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
+
+at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
+                              at::Tensor& w_scales,
+                              std::optional<at::Tensor> bias) {
+  return int4_scaled_mm_cpu_with_quant(x, w, w_scales, w_zeros, bias,
+                                       x.scalar_type());
+}
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 15b254662..a1d7d361d 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
                                      const std::optional<at::Tensor>& bias,
                                      at::ScalarType out_dtype, bool is_vnni);
 
+// Adapted from sglang: INT4 W4A8 kernels
+std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
+    at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
+
+at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
+                              at::Tensor& w_scales,
+                              std::optional<at::Tensor> bias);
+
 torch::Tensor get_scheduler_metadata(
     const int64_t num_req, const int64_t num_heads_q,
     const int64_t num_heads_kv, const int64_t head_dim,
@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
   ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
            &int8_scaled_mm_with_quant);
+
+  // Adapted from sglang: INT4 W4A8 kernels
+  ops.def(
+      "convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
+      "Tensor scales) -> (Tensor, Tensor, Tensor)");
+  ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
+           &convert_weight_packed_scale_zp);
+
+  ops.def(
+      "int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
+      "Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
+  ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
 #endif
 
   // CPU attention kernels
diff --git a/tests/kernels/test_awq_int4_to_int8.py b/tests/kernels/test_awq_int4_to_int8.py
new file mode 100644
index 000000000..829c0d227
--- /dev/null
+++ b/tests/kernels/test_awq_int4_to_int8.py
@@ -0,0 +1,281 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for AWQ INT4 W4A8 GEMM pipeline (SGLang kernel migration).
+
+Part 1: Weight packing tests
+  - convert_weight_packed_scale_zp correctness
+
+Part 2: INT4 W4A8 GEMM tests
+  - int4_scaled_mm_cpu correctness w.r.t. float reference
+  - Bias, 3D input, various shapes
+
+Part 3: create_weights shapes
+
+cmd:
+  VLLM_CPU_INT4_W4A8=1 python -m pytest tests/kernels/test_awq_int4_to_int8.py -v -s
+"""
+
+import numpy as np
+import pytest
+import torch
+
+from vllm._custom_ops import _supports_cpu_w4a8_int8
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    pack_cols,
+)
+from vllm.platforms import current_platform
+
+if not current_platform.is_cpu():
+    pytest.skip("skipping CPU-only tests", allow_module_level=True)
+
+requires_cpu_w4a8_int8 = pytest.mark.skipif(
+    not _supports_cpu_w4a8_int8,
+    reason="Requires vLLM CPU build with SGLang INT4 W4A8 kernels",
+)
+
+
+def make_awq_checkpoint_data(K, N, group_size, seed=42):
+    """Create synthetic AWQ checkpoint data in packed int32 format.
+
+    Returns:
+        packed_qweight: [K, N//8] int32 (AWQ interleaved + packed)
+        packed_qzeros: [num_groups, N//8] int32 (AWQ interleaved + packed)
+        scales: [num_groups, N] float32
+        float_ref: [K, N] float32, reference dequantized weights
+        weight_int4_orig: [K, N] int32, original int4 values (0-15)
+        zeros_int4_orig: [num_groups, N] int32, original zero points (0-15)
+    """
+    rng = np.random.RandomState(seed)
+    num_groups = K // group_size
+
+    weight_int4_orig = torch.from_numpy(
+        rng.randint(0, 16, size=(K, N)).astype(np.int32)
+    )
+    zeros_int4_orig = torch.from_numpy(
+        rng.randint(0, 16, size=(num_groups, N)).astype(np.int32)
+    )
+    scales = torch.from_numpy((rng.randn(num_groups, N) * 0.05).astype(np.float32))
+
+    scales_exp = scales.repeat_interleave(group_size, dim=0)
+    zeros_exp = zeros_int4_orig.repeat_interleave(group_size, dim=0)
+    float_ref = (weight_int4_orig.float() - zeros_exp.float()) * scales_exp
+
+    awq_interleave = [0, 2, 4, 6, 1, 3, 5, 7]
+    weight_interleaved = (
+        weight_int4_orig.reshape(-1, 8)[:, awq_interleave].reshape(K, N).contiguous()
+    )
+    packed_qweight = pack_cols(weight_interleaved, 4, K, N)
+
+    zeros_interleaved = (
+        zeros_int4_orig.reshape(-1, 8)[:, awq_interleave]
+        .reshape(num_groups, N)
+        .contiguous()
+    )
+    packed_qzeros = pack_cols(zeros_interleaved, 4, num_groups, N)
+
+    return (
+        packed_qweight,
+        packed_qzeros,
+        scales,
+        float_ref,
+        weight_int4_orig,
+        zeros_int4_orig,
+    )
+
+
+class TestConvertWeightPackedScaleZp:
+    """Tests for convert_weight_packed_scale_zp weight packing."""
+
+    @requires_cpu_w4a8_int8
+    @pytest.mark.parametrize(
+        "K,N,group_size",
+        [
+            (128, 128, 128),
+            (256, 256, 128),
+            (512, 256, 64),
+        ],
+    )
+    def test_packing_output_shapes(self, K, N, group_size):
+        """Packed outputs should have expected shapes."""
+        (packed_qweight, packed_qzeros, scales, _, _, _) = make_awq_checkpoint_data(
+            K, N, group_size
+        )
+
+        blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
+            packed_qweight, packed_qzeros, scales
+        )
+
+        block_n = 32
+        Nc = N // block_n
+
+        assert blocked_w.dim() >= 2, (
+            f"blocked_w should have >= 2 dims, got {blocked_w.dim()}"
+        )
+        assert blocked_s.size(0) == Nc, (
+            f"Expected Nc={Nc} scale blocks, got {blocked_s.size(0)}"
+        )
+        assert blocked_zp.size(0) == Nc, (
+            f"Expected Nc={Nc} qzeros blocks, got {blocked_zp.size(0)}"
+        )
+
+        print(
+            f"  [PASS] packing shapes K={K}, N={N}, gs={group_size}: "
+            f"blocked_w={list(blocked_w.shape)}, "
+            f"blocked_s={list(blocked_s.shape)}, blocked_zp={list(blocked_zp.shape)}"
+        )
+
+
+class TestInt4ScaledMmCpu:
+    """Tests for int4_scaled_mm_cpu GEMM kernel."""
+
+    @requires_cpu_w4a8_int8
+    @pytest.mark.parametrize(
+        "M,K,N,group_size",
+        [
+            (1, 128, 128, 128),
+            (4, 256, 256, 128),
+            (16, 512, 256, 64),
+            (32, 256, 512, 128),
+            (64, 512, 512, 128),
+        ],
+    )
+    def test_gemm_vs_float_reference(self, M, K, N, group_size):
+        """INT4 W4A8 GEMM should approximate float matmul."""
+        (packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
+            make_awq_checkpoint_data(K, N, group_size)
+        )
+
+        blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
+            packed_qweight, packed_qzeros, scales
+        )
+
+        x = torch.randn(M, K, dtype=torch.bfloat16)
+        out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, None)
+
+        ref_out = torch.mm(x.float(), float_ref)
+
+        abs_diff = (out.float() - ref_out).abs()
+        mean_abs = abs_diff.mean().item()
+        pct95 = torch.quantile(abs_diff, 0.95).item()
+        ref_mag = ref_out.abs().mean().item() + 1e-6
+        mean_rel = mean_abs / ref_mag
+
+        assert mean_rel < 0.05, (
+            f"Mean relative error {mean_rel:.4f} exceeds 5% threshold"
+        )
+        assert pct95 < ref_mag * 0.15, (
+            f"95th-pctile abs_diff {pct95:.4f} exceeds 15% of ref magnitude"
+        )
+        print(f"  [PASS] INT4 GEMM correct: M={M}, K={K}, N={N}")
+
+    @requires_cpu_w4a8_int8
+    @pytest.mark.parametrize("M", [1, 8, 32])
+    def test_gemm_with_bias(self, M):
+        """INT4 W4A8 GEMM with bias should match reference."""
+        K, N, group_size = 256, 128, 128
+        (packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
+            make_awq_checkpoint_data(K, N, group_size)
+        )
+
+        blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
+            packed_qweight, packed_qzeros, scales
+        )
+
+        bias = torch.randn(N, dtype=torch.float32)
+        x = torch.randn(M, K, dtype=torch.bfloat16)
+
+        out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, bias)
+
+        ref_out = torch.mm(x.float(), float_ref) + bias
+        abs_diff = (out.float() - ref_out).abs()
+        mean_abs = abs_diff.mean().item()
+        ref_mag = ref_out.abs().mean().item() + 1e-6
+        mean_rel = mean_abs / ref_mag
+        assert mean_rel < 0.05, (
+            f"Mean relative error {mean_rel:.4f} with bias exceeds 5%"
+        )
+        print(f"  [PASS] INT4 GEMM with bias: M={M}")
+
+    @requires_cpu_w4a8_int8
+    def test_gemm_3d_input(self):
+        """apply() reshapes 3D input [B, S, K] -> [B*S, K] -> back to 3D."""
+        K, N, group_size = 256, 128, 128
+        (packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
+            make_awq_checkpoint_data(K, N, group_size)
+        )
+
+        blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
+            packed_qweight, packed_qzeros, scales
+        )
+
+        B, S = 2, 8
+        x_3d = torch.randn(B, S, K, dtype=torch.bfloat16)
+        x_2d = x_3d.reshape(-1, K)
+
+        out_2d = torch.ops._C.int4_scaled_mm_cpu(
+            x_2d, blocked_w, blocked_zp, blocked_s, None
+        )
+        out_3d = out_2d.reshape(B, S, N)
+
+        ref_out = torch.mm(x_2d.float(), float_ref).reshape(B, S, N)
+
+        assert out_3d.shape == (B, S, N)
+        abs_diff = (out_3d.float() - ref_out).abs()
+        mean_abs = abs_diff.mean().item()
+        ref_mag = ref_out.abs().mean().item() + 1e-6
+        mean_rel = mean_abs / ref_mag
+
+        assert mean_rel < 0.05, f"Mean relative error {mean_rel:.4f} for 3D exceeds 5%"
+        print(f"  [PASS] 3D input [{B},{S},{K}] -> output [{B},{S},{N}]")
+
+    @requires_cpu_w4a8_int8
+    def test_gemm_fp16_input(self):
+        """INT4 GEMM should also work with fp16 input."""
+        K, N, group_size, M = 256, 256, 128, 8
+        (packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
+            make_awq_checkpoint_data(K, N, group_size)
+        )
+
+        blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
+            packed_qweight, packed_qzeros, scales
+        )
+
+        x = torch.randn(M, K, dtype=torch.float16)
+        out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, None)
+
+        ref_out = torch.mm(x.float(), float_ref)
+        abs_diff = (out.float() - ref_out).abs()
+        ref_mag = ref_out.abs().mean().item() + 1e-6
+        mean_rel = abs_diff.mean().item() / ref_mag
+
+        assert mean_rel < 0.05, (
+            f"Mean relative error {mean_rel:.4f} for fp16 exceeds 5%"
+        )
+        print(f"  [PASS] fp16 input M={M}, K={K}, N={N}")
+
+
+class TestCreateWeightsUnchanged:
+    """create_weights should still produce correct int4 placeholder shapes."""
+
+    @pytest.mark.parametrize(
+        "K,N,group_size",
+        [
+            (128, 128, 128),
+            (256, 256, 128),
+            (512, 256, 64),
+        ],
+    )
+    def test_int4_placeholder_shapes(self, K, N, group_size):
+        """Verify qweight, qzeros, scales shapes."""
+        pack_factor = 8
+        num_groups = K // group_size
+
+        qweight = torch.empty(K, N // pack_factor, dtype=torch.int32)
+        qzeros = torch.empty(num_groups, N // pack_factor, dtype=torch.int32)
+        scales = torch.empty(num_groups, N, dtype=torch.bfloat16)
+
+        assert qweight.shape == (K, N // pack_factor)
+        assert qzeros.shape == (num_groups, N // pack_factor)
+        assert scales.shape == (num_groups, N)
+        print(f"  [PASS] create_weights shapes: K={K}, N={N}, gs={group_size}")
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 7fef4b71a..ea54aaa95 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2967,6 +2967,38 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
         return torch.empty((M, N), dtype=out_dtype)
 
 
+if hasattr(torch.ops._C, "convert_weight_packed_scale_zp"):
+
+    @register_fake("_C::convert_weight_packed_scale_zp")
+    def convert_weight_packed_scale_zp_fake(
+        qweight: torch.Tensor,
+        qzeros: torch.Tensor,
+        scales: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return (
+            torch.empty_like(qweight),
+            torch.empty_like(qzeros),
+            torch.empty_like(scales),
+        )
+
+
+if hasattr(torch.ops._C, "int4_scaled_mm_cpu"):
+
+    @register_fake("_C::int4_scaled_mm_cpu")
+    def int4_scaled_mm_cpu_fake(
+        x: torch.Tensor,
+        w: torch.Tensor,
+        w_zeros: torch.Tensor,
+        w_scales: torch.Tensor,
+        bias: torch.Tensor | None,
+    ) -> torch.Tensor:
+        N = w_scales.size(0) * w_scales.size(-1)
+        return torch.empty((x.size(0), N), dtype=x.dtype, device=x.device)
+
+
+_supports_cpu_w4a8_int8 = bool(hasattr(torch.ops._C, "convert_weight_packed_scale_zp"))
+
+
 class CPUDNNLGEMMHandler:
     def __init__(self) -> None:
         self.handler_tensor: torch.Tensor | None = None
diff --git a/vllm/envs.py b/vllm/envs.py
index 0bd8f0fec..fb68fc4b1 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
     VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
     VLLM_CPU_SGL_KERNEL: bool = False
     VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
+    VLLM_CPU_INT4_W4A8: bool = True
     VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
@@ -728,6 +729,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
         int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
     ),
+    # (CPU backend only) whether to use SGLang INT4 W4A8 kernels for AWQ.
+    "VLLM_CPU_INT4_W4A8": lambda: bool(int(os.getenv("VLLM_CPU_INT4_W4A8", "1"))),
     # If the env var is set, Ray Compiled Graph uses the specified
     # channel type to communicate between workers belonging to
     # different pipeline-parallel stages.
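For reference, the AWQ storage layout consumed by the kernels above packs eight 4-bit values into each int32 in the interleaved nibble order [0, 2, 4, 6, 1, 3, 5, 7]; `autoawq_to_int4pack` undoes this with per-nibble right shifts. A minimal pure-Python sketch of the same unpacking (the helper name `unpack_awq_int4` is ours, for illustration only, not part of the patch):

```python
import torch

# Inverse of AutoAWQ's per-int32 nibble order [0, 2, 4, 6, 1, 3, 5, 7];
# mirrors the kernel's `bitshifts = at::tensor({0, 4, 1, 5, 2, 6, 3, 7}) * 4`.
REVERSE_AWQ_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

def unpack_awq_int4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack [R, C] int32 -> [R, C * 8] int4 values (0..15) in logical order."""
    shifts = torch.tensor(REVERSE_AWQ_ORDER, dtype=torch.int32) * 4
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF  # one nibble per shift
    return nibbles.flatten(-2)
```

The same interleave list appears as `awq_interleave` in the unit tests when synthesizing checkpoint data.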
diff --git a/vllm/model_executor/layers/quantization/cpu_wna16.py b/vllm/model_executor/layers/quantization/cpu_wna16.py index 3dba31743..8ec569042 100644 --- a/vllm/model_executor/layers/quantization/cpu_wna16.py +++ b/vllm/model_executor/layers/quantization/cpu_wna16.py @@ -7,9 +7,8 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers import PretrainedConfig -from vllm._custom_ops import ( - cpu_gemm_wna16, -) +import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import ( LinearBase, @@ -230,7 +229,14 @@ class CPUAWQLinearMethod(LinearMethodBase): layer.register_parameter("scales", scales) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False) + layer.use_w4a8 = envs.VLLM_CPU_INT4_W4A8 and torch.cpu._is_amx_tile_supported() + if layer.use_w4a8: + self._process_weights_sglang_int4(layer) + else: + self._process_weights_woq(layer) + + def _process_weights_woq(self, layer: torch.nn.Module) -> None: + """Original WOQ int4 repack path.""" packed_weight = layer.qweight.data packed_zeros = layer.qzeros.data group_num = packed_zeros.size(0) @@ -266,8 +272,6 @@ class CPUAWQLinearMethod(LinearMethodBase): ) zeros = pack_cols(zeros, bits, group_num, output_size).contiguous() - # make 16 output channel as a block and transpose to - # the make the block contiguous weight = pack_cols(weight, bits, input_size, output_size) weight = ( weight.view(input_size, -1, 16 // pack_factor) @@ -278,13 +282,40 @@ class CPUAWQLinearMethod(LinearMethodBase): layer.qweight.data = weight layer.qzeros.data = zeros + def _process_weights_sglang_int4(self, layer: torch.nn.Module) -> None: + """SGLang INT4 W4A8 path: pack int4 weights with VNNI reordering.""" + packed_weight = layer.qweight.data + packed_zeros = layer.qzeros.data + scales = layer.scales.data + blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp( + packed_weight, packed_zeros, scales + ) + + layer.packed_weight = blocked_w + layer.packed_qzeros = blocked_zp + layer.packed_scales = blocked_s + layer.qweight = None + layer.qzeros = None + layer.scales = None + def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - x = cpu_gemm_wna16( + if layer.use_w4a8: + return self._apply_sglang_int4(layer, x, bias) + return self._apply_woq(layer, x, bias) + + def _apply_woq( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + """Original WOQ int4 GEMM path.""" + x = ops.cpu_gemm_wna16( input=x, q_weight=layer.qweight, scales=layer.scales, @@ -296,6 +327,26 @@ class CPUAWQLinearMethod(LinearMethodBase): ) return x + def _apply_sglang_int4( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + """SGLang INT4 W4A8 GEMM path.""" + x_shape = x.shape + x_2d = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x + + out = torch.ops._C.int4_scaled_mm_cpu( + x_2d, + layer.packed_weight, + layer.packed_qzeros, + layer.packed_scales, + bias, + ) + out = out.reshape(x_shape[:-1] + (out.size(-1),)) if len(x_shape) > 2 else out + return out + def _get_isa_hint(dtype: torch.dtype) -> str: supports_amx = torch.cpu._is_amx_tile_supported()
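For context on how the pieces fit together, here is a minimal end-to-end sketch assuming a vLLM CPU build that registers the two new ops (the packed values are random, so this only exercises shapes and op signatures; shapes follow the unit tests):

```python
import torch
import vllm._custom_ops  # noqa: F401 -- loads the _C extension on a CPU build

M, K, N, group_size = 8, 256, 128, 128
num_groups = K // group_size

# AWQ-style inputs: eight 4-bit values packed per int32 column group.
packed_qweight = torch.randint(0, 2**31 - 1, (K, N // 8), dtype=torch.int32)
packed_qzeros = torch.randint(0, 2**31 - 1, (num_groups, N // 8), dtype=torch.int32)
scales = torch.rand(num_groups, N, dtype=torch.float32) * 0.05

# One-time repack into the blocked VNNI layout with zero-point compensation,
# as done by CPUAWQLinearMethod._process_weights_sglang_int4().
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
    packed_qweight, packed_qzeros, scales
)

# Per-call W4A8 GEMM; activations are quantized to int8 inside the kernel.
x = torch.randn(M, K, dtype=torch.bfloat16)
out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, None)
assert out.shape == (M, N)
```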