[Perf] Optimize Vectorization Utils for Int 8 Quantization Kernels (#20331)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
|||||||
|
|
||||||
// calculate for absmax
|
// calculate for absmax
|
||||||
float thread_max = 0.f;
|
float thread_max = 0.f;
|
||||||
for (int i = tid; i < hidden_size; i += stride) {
|
vectorize_read_with_alignment<16>(
|
||||||
const auto v = fabsf(static_cast<float>(row_in[i]));
|
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||||
|
const float v = fabsf(static_cast<float>(src));
|
||||||
thread_max = fmaxf(thread_max, v);
|
thread_max = fmaxf(thread_max, v);
|
||||||
}
|
});
|
||||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||||
__shared__ typename BlockReduce::TempStorage tmp;
|
__shared__ typename BlockReduce::TempStorage tmp;
|
||||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||||
@@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
|||||||
|
|
||||||
// 1. calculate min & max
|
// 1. calculate min & max
|
||||||
MinMax thread_mm;
|
MinMax thread_mm;
|
||||||
for (int i = tid; i < hidden_size; i += stride) {
|
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
|
||||||
thread_mm += static_cast<float>(row_in[i]);
|
[&] __device__(const scalar_t& src) {
|
||||||
}
|
thread_mm += static_cast<float>(src);
|
||||||
|
});
|
||||||
|
|
||||||
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
||||||
__shared__ typename BlockReduce::TempStorage tmp;
|
__shared__ typename BlockReduce::TempStorage tmp;
|
||||||
|
|||||||
@@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
|
|||||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
||||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||||
|
|
||||||
|
// fast path when the whole region is already aligned
|
||||||
|
// Note: currently the output is guaranteed to be same as the input, so we
|
||||||
|
// don't check it here, comments here just for future reference.
|
||||||
|
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||||
|
if (can_vec) {
|
||||||
|
int num_vec = len / VEC_SIZE;
|
||||||
|
|
||||||
|
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||||
|
using vout_t = vec_n_t<OutT, VEC_SIZE>;
|
||||||
|
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||||
|
auto* v_out = reinterpret_cast<vout_t*>(out);
|
||||||
|
|
||||||
|
for (int i = tid; i < num_vec; i += stride) {
|
||||||
|
vout_t tmp;
|
||||||
|
vec_op(tmp, v_in[i]);
|
||||||
|
v_out[i] = tmp;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
||||||
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
||||||
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
||||||
@@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
|
|||||||
std::forward<ScaOp>(scalar_op));
|
std::forward<ScaOp>(scalar_op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Adapter that turns a per-element functor into a per-vector functor.
//
// Given one vec_n_t<InT, VEC_SIZE>, operator() invokes `scalar_op` once for
// each of the VEC_SIZE entries of `src.val`, in ascending index order. Any
// value returned by `scalar_op` is discarded; the vector itself is only read.
template <int VEC_SIZE, typename InT, typename ScaOp>
struct DefaultReadVecOp {
  // Callback applied to every element of the vector.
  ScaOp scalar_op;

  __device__ __forceinline__ void operator()(
      const vec_n_t<InT, VEC_SIZE>& src) const {
    // Trip count is a compile-time constant, so ask for full unrolling.
#pragma unroll
    for (int elem = 0; elem < VEC_SIZE; ++elem) {
      scalar_op(src.val[elem]);
    }
  }
};
|
||||||
|
|
||||||
|
// Read-only traversal of in[0, len) that uses vector-wide accesses wherever
// the address is aligned to WIDTH = VEC_SIZE * sizeof(InT) bytes.
//
//   in, len     : input range to visit.
//   tid, stride : this caller starts at index `tid` of each phase and advances
//                 by `stride`.
//   vec_op      : called with a `const vec_n_t<InT, VEC_SIZE>&` for each
//                 aligned vector this caller owns.
//   scalar_op   : called with a `const InT&` for each element this caller
//                 owns outside the vectorized region.
//
// Nothing is written through `in`; all results must be accumulated by the
// functors themselves.
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
                                                     int tid, int stride,
                                                     VecOp&& vec_op,
                                                     ScaOp&& scalar_op) {
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");
  using vin_t = vec_n_t<InT, VEC_SIZE>;
  constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // bytes per vector access

  // Byte offset of `in` past the previous WIDTH-aligned boundary.
  const uintptr_t base_addr = reinterpret_cast<uintptr_t>(in);
  const int misaligned_bytes = base_addr & (WIDTH - 1);

  // Fast path: the pointer is aligned and the length is a whole number of
  // vectors, so there is neither a scalar prefix nor a scalar tail.
  if (misaligned_bytes == 0 && (len & (VEC_SIZE - 1)) == 0) {
    const vin_t* whole = reinterpret_cast<const vin_t*>(in);
    const int whole_count = len / VEC_SIZE;
    for (int v = tid; v < whole_count; v += stride) {
      vec_op(whole[v]);
    }
    return;
  }

  // Number of leading elements before the next aligned boundary. The mask
  // maps a distance of exactly WIDTH bytes (already aligned) to zero, and the
  // result is clamped so it never exceeds the input length.
  int head_elems = (WIDTH - misaligned_bytes) & (WIDTH - 1);
  head_elems /= sizeof(InT);
  head_elems = min(head_elems, len);

  // Phase 1: possibly unaligned prefix, one element at a time.
  for (int e = tid; e < head_elems; e += stride) {
    scalar_op(in[e]);
  }

  // Everything past the prefix starts on an aligned boundary.
  const InT* body = in + head_elems;
  const int body_len = len - head_elems;
  const int body_vecs = body_len / VEC_SIZE;
  const vin_t* body_as_vec = reinterpret_cast<const vin_t*>(body);

  // Phase 2: aligned middle region, one vector at a time.
  for (int v = tid; v < body_vecs; v += stride) {
    vec_op(body_as_vec[v]);
  }

  // Phase 3: leftover elements that do not fill a complete vector.
  for (int e = body_vecs * VEC_SIZE + tid; e < body_len; e += stride) {
    scalar_op(body[e]);
  }
}
|
||||||
|
|
||||||
|
// Convenience overload for callers that only have a per-element functor.
//
// A DefaultReadVecOp holding a copy of `scalar_op` is used as the vector
// functor, so each vector is processed by calling `scalar_op` on each of its
// elements. The original `scalar_op` is forwarded unchanged for the scalar
// prefix and tail.
template <int VEC_SIZE, typename InT, typename ScaOp>
__device__ __forceinline__ void vectorize_read_with_alignment(
    const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
  using PerVectorOp = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
  vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride,
                                          PerVectorOp{scalar_op},
                                          std::forward<ScaOp>(scalar_op));
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|||||||
Reference in New Issue
Block a user