@@ -21,9 +21,18 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
static constexpr auto i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
// To match the rounding mode of CUDA, we use nearbyint.
|
||||
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
|
||||
// If that changes in the future, we may need to set the rounding mode
|
||||
// explicitly, either at runtime or compile time.
|
||||
float dst = std::nearbyint(x);
|
||||
|
||||
// Replace std::clamp due to hip-clang issues
|
||||
// saturate
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// dst = std::clamp(dst, i8_min, i8_max);
|
||||
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
@@ -36,16 +45,26 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
|
||||
static inline __device__ int32_t float_to_int32_rn(float x) {
|
||||
#ifdef USE_ROCM
|
||||
// int32_max is not exactly representable as float.
|
||||
// Therefore, we need to be careful and manually return int32_max on overflow.
|
||||
// For symmetry, we also do the same for int32_min, even though it is exactly
|
||||
// representable as float and the conversion should be exact.
|
||||
static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
|
||||
static constexpr auto i32_min_f = static_cast<float>(i32_min);
|
||||
static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
|
||||
static constexpr auto i32_max_f = static_cast<float>(i32_max);
|
||||
|
||||
// To match the rounding mode of CUDA, we use nearbyint.
|
||||
// It uses the current rounding mode, which is always FE_TONEAREST on HIP.
|
||||
// If that changes in the future, we may need to set the rounding mode
|
||||
// explicitly, either at runtime or compile time.
|
||||
float dst = std::nearbyint(x);
|
||||
|
||||
// saturate on the higher end.
|
||||
if (dst >= i32_max_f) {
|
||||
return i32_max;
|
||||
}
|
||||
// saturate on the lower end.
|
||||
if (dst <= i32_min_f) {
|
||||
return i32_min;
|
||||
}
|
||||
@@ -66,7 +85,12 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
|
||||
static constexpr auto i8_max =
|
||||
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
// Replace std::clamp due to hip-clang issues
|
||||
// saturate
|
||||
// See https://github.com/pytorch/pytorch/issues/127666
|
||||
// See https://github.com/llvm/llvm-project/issues/95183
|
||||
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
|
||||
// Arch/gcc14. The following replaces std::clamp usage with similar logic
|
||||
// int32_t dst = std::clamp(x, i8_min, i8_max);
|
||||
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
|
||||
return static_cast<int8_t>(dst);
|
||||
#else
|
||||
@@ -88,6 +112,7 @@ __global__ void static_scaled_int8_quant_kernel(
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const float scale = *scale_ptr;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
@@ -109,6 +134,7 @@ __global__ void static_scaled_int8_azp_quant_kernel(
|
||||
const azp_t azp = *azp_ptr;
|
||||
const float inv_s = 1.0f / scale;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
@@ -128,9 +154,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
vectorize_read_with_alignment<16>(
|
||||
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||
@@ -172,6 +200,7 @@ struct MinMax {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// merge two MinMax objects
|
||||
__host__ __device__ MinMax& operator&=(const MinMax& other) {
|
||||
min = fminf(min, other.min);
|
||||
max = fmaxf(max, other.max);
|
||||
@@ -194,6 +223,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
@@ -218,7 +248,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
__shared__ azp_t azp_sh;
|
||||
if (tid == 0) {
|
||||
float s = (mm.max - mm.min) / 255.f;
|
||||
float zp = nearbyintf(-128.f - mm.min / s);
|
||||
float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
|
||||
scale_sh = s;
|
||||
azp_sh = azp_t(zp);
|
||||
scale_out[blockIdx.x] = s;
|
||||
|
||||
Reference in New Issue
Block a user