[Perf] Vectorize static / dynamic INT8 quant kernels (#19233)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "../vectorization_utils.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#else
|
||||
#include <hipcub/util_type.hpp>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_type.hpp>
|
||||
#endif
|
||||
|
||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
@@ -103,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename scale_type>
|
||||
template <typename scalar_t, typename scale_t>
|
||||
__global__ void static_scaled_int8_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type const* scale_ptr, const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
scale_type const scale = *scale_ptr;
|
||||
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
|
||||
const scale_t* scale_ptr, const int hidden_size) {
|
||||
const int tid = threadIdx.x;
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const float scale = *scale_ptr;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
|
||||
}
|
||||
vectorize_with_alignment<16>(
|
||||
row_in, row_out, hidden_size, tid, stride,
|
||||
[=] __device__(int8_t& dst, const scalar_t& src) {
|
||||
dst = float_to_int8_rn(static_cast<float>(src) / scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename scale_type, typename azp_type>
|
||||
template <typename scalar_t, typename scale_t, typename azp_t>
|
||||
__global__ void static_scaled_int8_azp_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type const* scale_ptr, azp_type const* azp_ptr,
|
||||
const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
scale_type const scale = *scale_ptr;
|
||||
azp_type const azp = *azp_ptr;
|
||||
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
|
||||
const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
|
||||
const int tid = threadIdx.x;
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const float scale = *scale_ptr;
|
||||
const azp_t azp = *azp_ptr;
|
||||
const float inv_s = 1.0f / scale;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
auto const val = static_cast<float>(input[i]);
|
||||
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
|
||||
out[i] = quant_val;
|
||||
}
|
||||
vectorize_with_alignment<16>(
|
||||
row_in, row_out, hidden_size, tid, stride,
|
||||
[=] __device__(int8_t& dst, const scalar_t& src) {
|
||||
const auto v = static_cast<float>(src) * inv_s;
|
||||
dst = int32_to_int8(float_to_int32_rn(v) + azp);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename scale_type>
|
||||
template <typename scalar_t, typename scale_t>
|
||||
__global__ void dynamic_scaled_int8_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type* scale, const int hidden_size) {
|
||||
int const tid = threadIdx.x;
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
float absmax_val = 0.0f;
|
||||
float const zero = 0.0f;
|
||||
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
|
||||
scale_t* scale_out, const int hidden_size) {
|
||||
const int tid = threadIdx.x;
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
float val = static_cast<float>(input[i]);
|
||||
val = val > zero ? val : -val;
|
||||
absmax_val = val > absmax_val ? val : absmax_val;
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
const auto v = fabsf(static_cast<float>(row_in[i]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||
float const block_absmax_val_maybe =
|
||||
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
__shared__ float block_absmax_val;
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||
__shared__ float absmax;
|
||||
if (tid == 0) {
|
||||
block_absmax_val = block_absmax_val_maybe;
|
||||
scale[token_idx] = block_absmax_val / 127.0f;
|
||||
absmax = block_max;
|
||||
scale_out[blockIdx.x] = absmax / 127.f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float const tmp_scale = 127.0f / block_absmax_val;
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
|
||||
}
|
||||
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
|
||||
|
||||
// 2. quantize
|
||||
vectorize_with_alignment<16>(
|
||||
row_in, row_out, hidden_size, tid, stride,
|
||||
[=] __device__(int8_t& dst, const scalar_t& src) {
|
||||
dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename scale_type, typename azp_type>
|
||||
// MinMax structure to hold min and max values in one go
|
||||
struct MinMax {
|
||||
float min, max;
|
||||
|
||||
__host__ __device__ MinMax()
|
||||
: min(std::numeric_limits<float>::max()),
|
||||
max(std::numeric_limits<float>::lowest()) {}
|
||||
|
||||
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
|
||||
|
||||
// add a value to the MinMax
|
||||
__host__ __device__ MinMax& operator+=(float v) {
|
||||
min = fminf(min, v);
|
||||
max = fmaxf(max, v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
// merge two MinMax objects
|
||||
__host__ __device__ MinMax& operator&=(const MinMax& other) {
|
||||
min = fminf(min, other.min);
|
||||
max = fmaxf(max, other.max);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ inline MinMax operator+(MinMax a, float v) {
|
||||
return a += v;
|
||||
}
|
||||
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
|
||||
return a &= b;
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename scale_t, typename azp_t>
|
||||
__global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
|
||||
scale_type* scale, azp_type* azp, const int hidden_size) {
|
||||
int64_t const token_idx = blockIdx.x;
|
||||
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
|
||||
scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
|
||||
const int tid = threadIdx.x;
|
||||
const int stride = blockDim.x;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
|
||||
// Must be performed using 64-bit math to avoid integer overflow.
|
||||
out += token_idx * hidden_size;
|
||||
input += token_idx * hidden_size;
|
||||
const scalar_t* row_in = input + token_idx * hidden_size;
|
||||
int8_t* row_out = output + token_idx * hidden_size;
|
||||
|
||||
// Scan for the min and max value for this token
|
||||
float max_val = std::numeric_limits<float>::min();
|
||||
float min_val = std::numeric_limits<float>::max();
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
auto val = static_cast<float>(input[i]);
|
||||
max_val = std::max(max_val, val);
|
||||
min_val = std::min(min_val, val);
|
||||
// 1. calculate min & max
|
||||
MinMax thread_mm;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
thread_mm += static_cast<float>(row_in[i]);
|
||||
}
|
||||
|
||||
// Reduce the max and min values across the block
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||
max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
|
||||
__syncthreads(); // Make sure min doesn't mess with max shared memory
|
||||
min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
|
||||
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
|
||||
__shared__ scale_type scale_sh;
|
||||
__shared__ azp_type azp_sh;
|
||||
MinMax mm = BlockReduce(tmp).Reduce(
|
||||
thread_mm,
|
||||
[] __device__(MinMax a, const MinMax& b) {
|
||||
a &= b;
|
||||
return a;
|
||||
},
|
||||
blockDim.x);
|
||||
|
||||
// Compute the scale and zero point and store them, only on the first thread
|
||||
if (threadIdx.x == 0) {
|
||||
float const scale_val = (max_val - min_val) / 255.0f;
|
||||
// Use rounding to even (same as torch.round)
|
||||
auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
|
||||
auto const azp_val = static_cast<azp_type>(azp_float);
|
||||
|
||||
// Store the scale and azp into shared and global
|
||||
scale[token_idx] = scale_sh = scale_val;
|
||||
azp[token_idx] = azp_sh = azp_val;
|
||||
__shared__ float scale_sh;
|
||||
__shared__ azp_t azp_sh;
|
||||
if (tid == 0) {
|
||||
float s = (mm.max - mm.min) / 255.f;
|
||||
float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
|
||||
scale_sh = s;
|
||||
azp_sh = azp_t(zp);
|
||||
scale_out[blockIdx.x] = s;
|
||||
azp_out[blockIdx.x] = azp_sh;
|
||||
}
|
||||
|
||||
// Wait for the scale and azp to be computed
|
||||
__syncthreads();
|
||||
|
||||
float const scale_val = scale_sh;
|
||||
azp_type const azp_val = azp_sh;
|
||||
const float inv_s = 1.f / scale_sh;
|
||||
const azp_t azp = azp_sh;
|
||||
|
||||
// Quantize the values
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
auto const val = static_cast<float>(input[i]);
|
||||
auto const quant_val =
|
||||
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
|
||||
out[i] = quant_val;
|
||||
}
|
||||
// 2. quantize
|
||||
vectorize_with_alignment<16>(
|
||||
row_in, row_out, hidden_size, tid, stride,
|
||||
[=] __device__(int8_t& dst, const scalar_t& src) {
|
||||
const auto v = static_cast<float>(src) * inv_s;
|
||||
dst = int32_to_int8(float_to_int32_rn(v) + azp);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -247,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
dim3 const block(std::min(hidden_size, 256));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
|
||||
@@ -278,7 +316,7 @@ void dynamic_scaled_int8_quant(
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
dim3 const block(std::min(hidden_size, 256));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
|
||||
|
||||
75
csrc/quantization/vectorization_utils.cuh
Normal file
75
csrc/quantization/vectorization_utils.cuh
Normal file
@@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
#include "vectorization.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
|
||||
struct DefaultVecOp {
|
||||
ScaOp scalar_op;
|
||||
|
||||
__device__ __forceinline__ void operator()(
|
||||
vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
scalar_op(dst.val[i], src.val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
|
||||
typename ScaOp>
|
||||
__device__ inline void vectorize_with_alignment(
|
||||
const InT* in, OutT* out, int len, int tid, int stride,
|
||||
VecOp&& vec_op, // vec_n_t<InT,16> -> vec_n_t<OutT,16>
|
||||
ScaOp&& scalar_op) { // InT -> OutT
|
||||
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
|
||||
"VEC_SIZE must be a positive power-of-two");
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
||||
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
||||
prefix_elems /= sizeof(InT);
|
||||
prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16
|
||||
|
||||
// 1. prefill the when it is unsafe to vectorize
|
||||
for (int i = tid; i < prefix_elems; i += stride) {
|
||||
scalar_op(out[i], in[i]);
|
||||
}
|
||||
|
||||
in += prefix_elems;
|
||||
out += prefix_elems;
|
||||
len -= prefix_elems;
|
||||
|
||||
int num_vec = len / VEC_SIZE;
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
using vout_t = vec_n_t<OutT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
auto* v_out = reinterpret_cast<vout_t*>(out);
|
||||
|
||||
// 2. vectorize the main part
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vout_t tmp;
|
||||
vec_op(tmp, v_in[i]);
|
||||
v_out[i] = tmp;
|
||||
}
|
||||
|
||||
// 3. handle the tail
|
||||
int tail_start = num_vec * VEC_SIZE;
|
||||
for (int i = tid + tail_start; i < len; i += stride) {
|
||||
scalar_op(out[i], in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
|
||||
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
|
||||
OutT* out, int len,
|
||||
int tid, int stride,
|
||||
ScaOp&& scalar_op) {
|
||||
using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
|
||||
vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
Reference in New Issue
Block a user