diff --git a/CMakeLists.txt b/CMakeLists.txt
index e2cc0ccde..7625590e6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -243,7 +243,7 @@ set(VLLM_EXT_SRC
   "csrc/sampler.cu"
   "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
-  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
+  "csrc/quantization/8bit/int8/scaled_quant.cu"
   "csrc/quantization/fp8/common.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
@@ -297,7 +297,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
       "csrc/cutlass_extensions/common.cpp"
       "csrc/attention/mla/cutlass_mla_entry.cu"
-      "csrc/quantization/fp8/per_token_group_quant.cu")
+      "csrc/quantization/8bit/fp8/per_token_group_quant.cu"
+      "csrc/quantization/8bit/int8/per_token_group_quant.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/8bit/fp8/per_token_group_quant.cu
similarity index 98%
rename from csrc/quantization/fp8/per_token_group_quant.cu
rename to csrc/quantization/8bit/fp8/per_token_group_quant.cu
index f5b40e35b..39c606d11 100644
--- a/csrc/quantization/fp8/per_token_group_quant.cu
+++ b/csrc/quantization/8bit/fp8/per_token_group_quant.cu
@@ -8,9 +8,9 @@
 
 #include <cmath>
 
-#include "../vectorization.cuh"
-#include "../vectorization_utils.cuh"
-#include "../../dispatch_utils.h"
+#include "../../vectorization.cuh"
+#include "../../vectorization_utils.cuh"
+#include "../../../dispatch_utils.h"
 
 __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
   unsigned mask = 0xffff;
@@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
                                double fp8_max, bool scale_ue8m0) {
   per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                              fp8_min, fp8_max, scale_ue8m0);
-}
+}
\ No newline at end of file
diff --git a/csrc/quantization/8bit/int8/per_token_group_quant.cu b/csrc/quantization/8bit/int8/per_token_group_quant.cu
new file mode 100644
index 000000000..ab6f52eaf
--- /dev/null
+++ b/csrc/quantization/8bit/int8/per_token_group_quant.cu
@@ -0,0 +1,12 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include "../per_token_group_quant_8bit.h"
+
+void per_token_group_quant_int8(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double int8_min, double int8_max) {
+  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
+                             int8_min, int8_max);
+}
\ No newline at end of file
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/8bit/int8/scaled_quant.cu
similarity index 79%
rename from csrc/quantization/compressed_tensors/int8_quant_kernels.cu
rename to csrc/quantization/8bit/int8/scaled_quant.cu
index d8369108d..108e723d6 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/8bit/int8/scaled_quant.cu
@@ -1,14 +1,10 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 
-#ifndef USE_ROCM
-  #include "../per_token_group_quant_8bit.h"
-#endif
-
 #include <cmath>
 
-#include "../../dispatch_utils.h"
-#include "../vectorization_utils.cuh"
+#include "../../../dispatch_utils.h"
+#include "../../vectorization_utils.cuh"
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
@@ -25,19 +21,9 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
   static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
 
-  // To match the rounding mode of CUDA, we use nearbyint.
-  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
-  // If that changes in the future, we may need to set the rounding mode
-  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
 
-  // saturate
-
-  // See https://github.com/pytorch/pytorch/issues/127666
-  // See https://github.com/llvm/llvm-project/issues/95183
-  // hip-clang std::clamp __glibcxx_assert_fail host function when building on
-  // Arch/gcc14. The following replaces std::clamp usage with similar logic
-  // dst = std::clamp(dst, i8_min, i8_max);
+  // Replace std::clamp due to hip-clang issues
   dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
   return static_cast<int8_t>(dst);
 #else
@@ -50,26 +36,16 @@
 static inline __device__ int32_t float_to_int32_rn(float x) {
 #ifdef USE_ROCM
-  // int32_max is not exactly representable as float.
-  // Therefore, we need to be careful and manually return int32_max on overflow.
-  // For symmetry, we also do the same for int32_min, even though it is exactly
-  // representable as float and the conversion should be exact.
   static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
   static constexpr auto i32_min_f = static_cast<float>(i32_min);
   static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
   static constexpr auto i32_max_f = static_cast<float>(i32_max);
 
-  // To match the rounding mode of CUDA, we use nearbyint.
-  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
-  // If that changes in the future, we may need to set the rounding mode
-  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
 
-  // saturate on the higher end.
   if (dst >= i32_max_f) {
     return i32_max;
   }
 
-  // saturate on the lower end.
   if (dst <= i32_min_f) {
     return i32_min;
   }
@@ -90,13 +66,7 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
   static constexpr auto i8_max =
       static_cast<int32_t>(std::numeric_limits<int8_t>::max());
 
-  // saturate
-
-  // See https://github.com/pytorch/pytorch/issues/127666
-  // See https://github.com/llvm/llvm-project/issues/95183
-  // hip-clang std::clamp __glibcxx_assert_fail host function when building on
-  // Arch/gcc14. The following replaces std::clamp usage with similar logic
-  // int32_t dst = std::clamp(x, i8_min, i8_max);
+  // Replace std::clamp due to hip-clang issues
   int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
   return static_cast<int8_t>(dst);
 #else
@@ -118,7 +88,6 @@ __global__ void static_scaled_int8_quant_kernel(
   const int64_t token_idx = blockIdx.x;
   const float scale = *scale_ptr;
 
-  // Must be performed using 64-bit math to avoid integer overflow.
   const scalar_t* row_in = input + token_idx * hidden_size;
   int8_t* row_out = output + token_idx * hidden_size;
 
@@ -140,7 +109,6 @@ __global__ void static_scaled_int8_azp_quant_kernel(
   const azp_t azp = *azp_ptr;
   const float inv_s = 1.0f / scale;
 
-  // Must be performed using 64-bit math to avoid integer overflow.
   const scalar_t* row_in = input + token_idx * hidden_size;
   int8_t* row_out = output + token_idx * hidden_size;
 
@@ -160,11 +128,9 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   const int stride = blockDim.x;
   const int64_t token_idx = blockIdx.x;
 
-  // Must be performed using 64-bit math to avoid integer overflow.
   const scalar_t* row_in = input + token_idx * hidden_size;
   int8_t* row_out = output + token_idx * hidden_size;
 
-  // calculate for absmax
   float thread_max = 0.f;
   vectorize_read_with_alignment<16>(
       row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
@@ -183,7 +149,6 @@
 
   float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
 
-  // 2. quantize
   vectorize_with_alignment<16>(
       row_in, row_out, hidden_size, tid, stride,
       [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -201,14 +166,12 @@ struct MinMax {
   __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
 
-  // add a value to the MinMax
   __host__ __device__ MinMax& operator+=(float v) {
     min = fminf(min, v);
     max = fmaxf(max, v);
     return *this;
   }
 
-  // merge two MinMax objects
   __host__ __device__ MinMax& operator&=(const MinMax& other) {
     min = fminf(min, other.min);
     max = fmaxf(max, other.max);
     return *this;
   }
@@ -231,11 +194,9 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
   const int stride = blockDim.x;
   const int64_t token_idx = blockIdx.x;
 
-  // Must be performed using 64-bit math to avoid integer overflow.
   const scalar_t* row_in = input + token_idx * hidden_size;
   int8_t* row_out = output + token_idx * hidden_size;
 
-  // 1. calculate min & max
   MinMax thread_mm;
   vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
                                     [&] __device__(const scalar_t& src) {
@@ -257,7 +218,7 @@
   __shared__ azp_t azp_sh;
   if (tid == 0) {
     float s = (mm.max - mm.min) / 255.f;
-    float zp = nearbyintf(-128.f - mm.min / s);  // round-to-even
+    float zp = nearbyintf(-128.f - mm.min / s);
     scale_sh = s;
     azp_sh = azp_t(zp);
     scale_out[blockIdx.x] = s;
@@ -268,7 +229,6 @@
   const float inv_s = 1.f / scale_sh;
   const azp_t azp = azp_sh;
 
-  // 2. quantize
   vectorize_with_alignment<16>(
       row_in, row_out, hidden_size, tid, stride,
       [=] __device__(int8_t& dst, const scalar_t& src) {
@@ -339,14 +299,4 @@
             hidden_size);
       }
     });
-}
-
-#ifndef USE_ROCM
-void per_token_group_quant_int8(const torch::Tensor& input,
-                                torch::Tensor& output_q,
-                                torch::Tensor& output_s, int64_t group_size,
-                                double eps, double int8_min, double int8_max) {
-  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
-                             int8_min, int8_max);
-}
-#endif
+}
\ No newline at end of file
diff --git a/csrc/quantization/per_token_group_quant_8bit.h b/csrc/quantization/8bit/per_token_group_quant_8bit.h
similarity index 84%
rename from csrc/quantization/per_token_group_quant_8bit.h
rename to csrc/quantization/8bit/per_token_group_quant_8bit.h
index 537b61bc4..25d4ecd11 100644
--- a/csrc/quantization/per_token_group_quant_8bit.h
+++ b/csrc/quantization/8bit/per_token_group_quant_8bit.h
@@ -1,7 +1,6 @@
 #pragma once
 #include <torch/all.h>
 
-// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
 // 8-bit per-token-group quantization helper used by both FP8 and INT8
 void per_token_group_quant_8bit(const torch::Tensor& input,
                                 torch::Tensor& output_q,
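
With this refactor, both `per_token_group_quant_fp8` and the newly split-out `per_token_group_quant_int8` are thin wrappers over the shared `per_token_group_quant_8bit` helper declared in `csrc/quantization/8bit/per_token_group_quant_8bit.h`, so the two dtype paths cannot drift apart. For illustration, a minimal host-side sketch of calling the relocated INT8 entry point: the wrapper declaration is copied from this diff, while the caller, tensor shapes, `group_size`, and `eps` values are assumptions for the example, not part of the change.

```cpp
#include <torch/all.h>

// Declaration of the relocated wrapper, as defined above in
// csrc/quantization/8bit/int8/per_token_group_quant.cu.
void per_token_group_quant_int8(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double int8_min, double int8_max);

// Hypothetical caller: quantize a [num_tokens, hidden_size] activation
// tensor in groups of 128 elements, producing one float scale per group.
void quantize_per_group_int8(const torch::Tensor& input) {
  const int64_t group_size = 128;  // assumed; must evenly divide hidden_size
  auto output_q =
      torch::empty_like(input, input.options().dtype(torch::kInt8));
  auto output_s = torch::empty({input.size(0), input.size(1) / group_size},
                               input.options().dtype(torch::kFloat32));
  // int8_min/int8_max span the full signed 8-bit range; eps (assumed value)
  // guards the per-group scale computation against all-zero groups.
  per_token_group_quant_int8(input, output_q, output_s, group_size,
                             /*eps=*/1e-10,
                             /*int8_min=*/-128.0, /*int8_max=*/127.0);
}
```

Routing both dtypes through one `per_token_group_quant_8bit` helper also makes the new `8bit/` directory layout mirror the code-sharing structure: `fp8/` and `int8/` each hold only their entry points, while the common kernel lives one level up.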