diff --git a/csrc/quantization/8bit/int8/scaled_quant.cu b/csrc/quantization/8bit/int8/scaled_quant.cu index 108e723d6..09c7741fc 100644 --- a/csrc/quantization/8bit/int8/scaled_quant.cu +++ b/csrc/quantization/8bit/int8/scaled_quant.cu @@ -21,9 +21,18 @@ static inline __device__ int8_t float_to_int8_rn(float x) { static constexpr auto i8_max = static_cast(std::numeric_limits::max()); + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); - // Replace std::clamp due to hip-clang issues + // saturate + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst; return static_cast(dst); #else @@ -36,16 +45,26 @@ static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int32_t float_to_int32_rn(float x) { #ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. static constexpr auto i32_min = std::numeric_limits::min(); static constexpr auto i32_min_f = static_cast(i32_min); static constexpr auto i32_max = std::numeric_limits::max(); static constexpr auto i32_max_f = static_cast(i32_max); + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); + // saturate on the higher end. if (dst >= i32_max_f) { return i32_max; } + // saturate on the lower end. if (dst <= i32_min_f) { return i32_min; } @@ -66,7 +85,12 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static constexpr auto i8_max = static_cast(std::numeric_limits::max()); - // Replace std::clamp due to hip-clang issues + // saturate + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // int32_t dst = std::clamp(x, i8_min, i8_max); int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x; return static_cast(dst); #else @@ -88,6 +112,7 @@ __global__ void static_scaled_int8_quant_kernel( const int64_t token_idx = blockIdx.x; const float scale = *scale_ptr; + // Must be performed using 64-bit math to avoid integer overflow. const scalar_t* row_in = input + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size; @@ -109,6 +134,7 @@ __global__ void static_scaled_int8_azp_quant_kernel( const azp_t azp = *azp_ptr; const float inv_s = 1.0f / scale; + // Must be performed using 64-bit math to avoid integer overflow. const scalar_t* row_in = input + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size; @@ -128,9 +154,11 @@ __global__ void dynamic_scaled_int8_quant_kernel( const int stride = blockDim.x; const int64_t token_idx = blockIdx.x; + // Must be performed using 64-bit math to avoid integer overflow. const scalar_t* row_in = input + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size; + // calculate for absmax float thread_max = 0.f; vectorize_read_with_alignment<16>( row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) { @@ -172,6 +200,7 @@ struct MinMax { return *this; } + // merge two MinMax objects __host__ __device__ MinMax& operator&=(const MinMax& other) { min = fminf(min, other.min); max = fmaxf(max, other.max); @@ -194,6 +223,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( const int stride = blockDim.x; const int64_t token_idx = blockIdx.x; + // Must be performed using 64-bit math to avoid integer overflow. const scalar_t* row_in = input + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size; @@ -218,7 +248,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( __shared__ azp_t azp_sh; if (tid == 0) { float s = (mm.max - mm.min) / 255.f; - float zp = nearbyintf(-128.f - mm.min / s); + float zp = nearbyintf(-128.f - mm.min / s); // round-to-even scale_sh = s; azp_sh = azp_t(zp); scale_out[blockIdx.x] = s;