[Kernel] Dynamic Per-Token Activation Quantization (#5037)

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Dipika Sikka
2024-06-07 12:36:26 -04:00
committed by GitHub
parent dc49fb892c
commit ca3ea51bde
12 changed files with 439 additions and 75 deletions

View File

@@ -3,6 +3,7 @@
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
const scale_type* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
scale_type scale = *scale_ptr;
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
}
}
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / 127.0f;
}
__syncthreads();
float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
}
}
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
scale.data_ptr<float>(), hidden_size);
});
}
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
});
}