refactor quant folder

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
yewentao256
2025-08-07 15:05:05 -07:00
parent 7e3a8dc906
commit f07e10e9bc
5 changed files with 25 additions and 63 deletions

View File

@@ -0,0 +1,215 @@
#include <ATen/cuda/CUDAContext.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include <cuda_fp8.h>
#include <torch/all.h>
#include "../../vectorization.cuh"
#include "../../vectorization_utils.cuh"
#include "../../../dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
return val;
}
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
const T* __restrict__ input, void* __restrict__ output_q,
scale_packed_t* __restrict__ output_s, const int group_size,
const int num_groups, const int groups_per_block, const float eps,
const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
const int scale_stride = 0) {
const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group;
const int lane_id = threadIdx.x % threads_per_group;
const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id;
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
using scale_element_t = float;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
const T* group_input = input + block_group_offset;
DST_DTYPE* group_output =
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
scale_element_t* scale_output;
if constexpr (IS_COLUMN_MAJOR) {
const int num_elems_per_pack =
static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
const int row_idx = global_group_id / scale_num_rows_element;
const int col_idx_raw = global_group_id % scale_num_rows_element;
const int col_idx = col_idx_raw / num_elems_per_pack;
const int pack_idx = col_idx_raw % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(col_idx * scale_stride * num_elems_per_pack +
row_idx * num_elems_per_pack + pack_idx);
} else {
scale_output = output_s + global_group_id;
}
// shared memory to cache each group's data to avoid double DRAM reads.
extern __shared__ __align__(16) char smem_raw[];
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;
constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;
// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};
vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax, lane_id);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}
scale_element_t y_s_quant = y_s;
if (lane_id == 0) {
*scale_output = y_s_quant;
}
__syncthreads();
// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};
vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
}
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
const int num_groups = input.numel() / group_size;
TORCH_CHECK(input.numel() % group_size == 0);
TORCH_CHECK(output_s.dim() == 2);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
constexpr int THREADS_PER_GROUP = 16;
int groups_per_block = 1;
if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP;
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
const int scale_num_rows = output_s.size(1);
const int scale_stride = output_s.stride(1);
#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
size_t smem_bytes = \
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
if (is_column_major) { \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit, scale_num_rows, scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit, scale_num_rows, scale_stride); \
} \
} else { \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit); \
} \
} \
} while (0)
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
} else if (dst_type == at::ScalarType::Char) {
LAUNCH_KERNEL(scalar_t, int8_t);
}
}));
#undef LAUNCH_KERNEL
}
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0);
}

View File

@@ -0,0 +1,12 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "../per_token_group_quant_8bit.h"
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}

View File

@@ -0,0 +1,302 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>
#include "../../../dispatch_utils.h"
#include "../../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
static constexpr auto i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
static constexpr auto i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
float dst = std::nearbyint(x);
// Replace std::clamp due to hip-clang issues
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
static constexpr auto i32_min_f = static_cast<float>(i32_min);
static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
static constexpr auto i32_max_f = static_cast<float>(i32_max);
float dst = std::nearbyint(x);
if (dst >= i32_max_f) {
return i32_max;
}
if (dst <= i32_min_f) {
return i32_min;
}
return static_cast<int32_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int32_t&>(dst);
#endif
}
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
static constexpr auto i8_min =
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
static constexpr auto i8_max =
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// Replace std::clamp due to hip-clang issues
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}
namespace vllm {
template <typename scalar_t, typename scale_t>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
const scale_t* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
dst = float_to_int8_rn(static_cast<float>(src) / scale);
});
}
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void static_scaled_int8_azp_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
const azp_t azp = *azp_ptr;
const float inv_s = 1.0f / scale;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
const auto v = static_cast<float>(src) * inv_s;
dst = int32_to_int8(float_to_int32_rn(v) + azp);
});
}
template <typename scalar_t, typename scale_t>
__global__ void dynamic_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
scale_t* scale_out, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
float thread_max = 0.f;
vectorize_read_with_alignment<16>(
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
const float v = fabsf(static_cast<float>(src));
thread_max = fmaxf(thread_max, v);
});
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
__shared__ float absmax;
if (tid == 0) {
absmax = block_max;
scale_out[blockIdx.x] = absmax / 127.f;
}
__syncthreads();
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
});
}
// MinMax structure to hold min and max values in one go
struct MinMax {
float min, max;
__host__ __device__ MinMax()
: min(std::numeric_limits<float>::max()),
max(std::numeric_limits<float>::lowest()) {}
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
__host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v);
max = fmaxf(max, v);
return *this;
}
__host__ __device__ MinMax& operator&=(const MinMax& other) {
min = fminf(min, other.min);
max = fmaxf(max, other.max);
return *this;
}
};
__host__ __device__ inline MinMax operator+(MinMax a, float v) {
return a += v;
}
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
return a &= b;
}
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
MinMax thread_mm;
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) {
thread_mm += static_cast<float>(src);
});
using BlockReduce = cub::BlockReduce<MinMax, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
MinMax mm = BlockReduce(tmp).Reduce(
thread_mm,
[] __device__(MinMax a, const MinMax& b) {
a &= b;
return a;
},
blockDim.x);
__shared__ float scale_sh;
__shared__ azp_t azp_sh;
if (tid == 0) {
float s = (mm.max - mm.min) / 255.f;
float zp = nearbyintf(-128.f - mm.min / s);
scale_sh = s;
azp_sh = azp_t(zp);
scale_out[blockIdx.x] = s;
azp_out[blockIdx.x] = azp_sh;
}
__syncthreads();
const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh;
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
const auto v = static_cast<float>(src) * inv_s;
dst = int32_to_int8(float_to_int32_rn(v) + azp);
});
}
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp || azp->numel() == 1);
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
if (!azp) {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
} else {
vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
}
});
}
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(!azp || azp->is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 256));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
if (!azp) {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
} else {
vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
<<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
hidden_size);
}
});
}

View File

@@ -0,0 +1,9 @@
#pragma once
#include <torch/all.h>
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0 = false);