[Performance] Fused blockwise quant RMS norm (#27883)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
@@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
  // RMS Norm + Quant
  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
    token_scale = 1.0f / token_scale;
    vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
                                     has_residual>(
        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
        out, input, weight, rms, &token_scale, hidden_size, residual);
  } else {
    // FP8 - Do not invert token_scale for exact match with FBGemm
    vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
                                     has_residual>(
        out, input, weight, rms, token_scale, hidden_size, residual);
        out, input, weight, rms, &token_scale, hidden_size, residual);
  }
}

@@ -75,14 +76,52 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
  // RMS Norm + Quant
  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
    token_scale = 1.0f / token_scale;
    vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
        out, input, weight, rms, &token_scale, hidden_size, residual);
  } else {
    // FP8 - Do not invert s_token_scale for exact match with FBGemm
    vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
        out, input, weight, rms, token_scale, hidden_size, residual);
        out, input, weight, rms, &token_scale, hidden_size, residual);
  }
}

// RMS norm + quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false, int32_t group_size = 0>
__global__ void rms_norm_per_block_quant_kernel(
    scalar_out_t* __restrict__ out,  // [..., hidden_size]
    float* __restrict__ scales,      // [num_tokens, hidden_size / group_size]
                                     // or
                                     // [hidden_size / group_size, num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  float rms;
  // Compute RMS
  // Always able to vectorize due to constraints on hidden_size
  vllm::vectorized::compute_rms<scalar_t, has_residual>(
      &rms, input, hidden_size, var_epsilon, residual);

  // Compute Scale
  // Always able to vectorize due to constraints on hidden_size and group_size
  vllm::vectorized::compute_dynamic_per_token_scales<
      scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
      nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual);

  // RMS Norm + Quant
  // Always able to vectorize due to constraints on hidden_size
  // For int8, don't invert token_scale here: do it inside the norm_and_quant
  // kernel. We do it there because individual elements of token_scale can be
  // shared between multiple threads, so this avoids extra synchronization
  // overhead.
  vllm::vectorized::norm_and_quant<
      scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
      has_residual, is_scale_transposed, group_size>(
      out, input, weight, rms, scales, hidden_size, residual);
}

}  // namespace vllm

// Residual add + RMS norm + dynamic per token
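
Note: the kernel above emits one float scale per (token, group). A small standalone sketch of the index arithmetic behind the two scale layouts described in the comment (the helper name and the checks are illustrative, not part of this commit):

    #include <cassert>
    #include <cstdint>

    // Flattened index of the scale for element `elem` of token `token`.
    // Non-transposed layout: [num_tokens, hidden_size / group_size];
    // transposed layout:     [hidden_size / group_size, num_tokens].
    int64_t scale_index(int64_t token, int64_t elem, int64_t hidden_size,
                        int64_t group_size, int64_t num_tokens, bool transposed) {
      int64_t const num_groups = hidden_size / group_size;
      int64_t const group = elem / group_size;
      return transposed ? group * num_tokens + token : token * num_groups + group;
    }

    int main() {
      // hidden_size = 512, group_size = 128 -> 4 groups per token.
      assert(scale_index(2, 300, 512, 128, 8, false) == 2 * 4 + 2);  // row-major
      assert(scale_index(2, 300, 512, 128, 8, true) == 2 * 8 + 2);   // transposed
      return 0;
    }

In the device code, blockIdx.x plays the role of `token` and gridDim.x the role of `num_tokens`.
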
@@ -103,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (residual.has_value()) {
  VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        true>
                                                        has_residual>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
                  var_epsilon, hidden_size,
                  has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
        });

  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        false>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  var_epsilon, hidden_size, nullptr);
        });
  }
  });
}

void rms_norm_dynamic_per_token_quant(
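
Note: the refactor above folds the duplicated residual / no-residual launch paths into a single body by turning the runtime `residual.has_value()` flag into a compile-time template argument. A minimal sketch of that kind of boolean-dispatch helper (an assumption about the general shape of VLLM_DISPATCH_BOOL, not its actual definition):

    // Instantiates the lambda body twice, once per constexpr value of the flag,
    // and picks the right instantiation at runtime.
    #define DISPATCH_BOOL_SKETCH(RUNTIME_FLAG, CONST_NAME, ...) \
      [&] {                                                     \
        if (RUNTIME_FLAG) {                                     \
          constexpr bool CONST_NAME = true;                     \
          return __VA_ARGS__();                                 \
        } else {                                                \
          constexpr bool CONST_NAME = false;                    \
          return __VA_ARGS__();                                 \
        }                                                       \
      }()
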
@@ -157,3 +185,79 @@ void rms_norm_dynamic_per_token_quant(
        out, input, weight, scales, var_epsilon, scale_ub, residual);
      });
}

// Residual add + RMS norm + per-block quant
void rms_norm_per_block_quant_dispatch(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens, hidden_size / group_size] or
                                  // [hidden_size / group_size, num_tokens]
    int32_t group_size,
    double const var_epsilon,  // Variance epsilon used in norm calculation
    std::optional<at::Tensor> const& scale_ub,
    std::optional<at::Tensor>& residual, bool is_scale_transposed) {
  int32_t hidden_size = input.size(-1);
  auto num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  const int max_block_size = (num_tokens <= 256) ? 512 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] {
        using scalar_in_t = scalar_t;
        VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
          VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
            VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
              VLLM_DISPATCH_QUANT_TYPES(
                  out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
                    vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t,
                                                          has_residual,
                                                          transpose_scale, gs>
                        <<<grid, block, 0, stream>>>(
                            out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                            input.data_ptr<scalar_in_t>(),
                            weight.data_ptr<scalar_in_t>(),
                            scale_ub.has_value() ? scale_ub->data_ptr<float>()
                                                 : nullptr,
                            var_epsilon, hidden_size,
                            has_residual ? residual->data_ptr<scalar_in_t>()
                                         : nullptr);
                  });
            });
          });
        });
      });
}

void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& weight,
                              torch::Tensor& scales, double const var_epsilon,
                              std::optional<torch::Tensor> scale_ub,
                              std::optional<torch::Tensor> residual,
                              int64_t group_size, bool is_scale_transposed) {
  static c10::ScalarType kFp8Type = is_fp8_ocp()
                                        ? c10::ScalarType::Float8_e4m3fn
                                        : c10::ScalarType::Float8_e4m3fnuz;
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

  if (scale_ub.has_value()) {
    TORCH_CHECK(out.dtype() == kFp8Type);
  }
  TORCH_CHECK(weight.dtype() == input.dtype());
  TORCH_CHECK(scales.dtype() == torch::kFloat32);
  if (residual) {
    TORCH_CHECK(residual->scalar_type() == input.scalar_type());
  }

  TORCH_CHECK(group_size == 128 || group_size == 64,
              "Unsupported group size: ", group_size);

  rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
                                    var_epsilon, scale_ub, residual,
                                    is_scale_transposed);
}
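
Note: a hedged host-side usage sketch for the new entry point declared above, showing the shape and dtype constraints the TORCH_CHECKs enforce (assumes the declaration is visible and a CUDA build of libtorch; names and values are illustrative, not part of this commit):

    #include <torch/torch.h>
    #include <optional>

    void rms_norm_per_block_quant_example() {
      int64_t const num_tokens = 16, hidden_size = 1024, group_size = 128;
      auto opts = torch::TensorOptions().device(torch::kCUDA);
      auto input = torch::randn({num_tokens, hidden_size}, opts.dtype(torch::kBFloat16));
      auto weight = torch::ones({hidden_size}, opts.dtype(torch::kBFloat16));
      auto out = torch::empty({num_tokens, hidden_size},
                              opts.dtype(torch::kFloat8_e4m3fn));
      // Non-transposed scale layout: [num_tokens, hidden_size / group_size].
      auto scales = torch::empty({num_tokens, hidden_size / group_size},
                                 opts.dtype(torch::kFloat32));
      std::optional<torch::Tensor> scale_ub, residual;  // both left unset here
      rms_norm_per_block_quant(out, input, weight, scales, /*var_epsilon=*/1e-6,
                               scale_ub, residual, group_size,
                               /*is_scale_transposed=*/false);
    }
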
@@ -9,6 +9,7 @@
#include "quant_conversions.cuh"

#include "../../cub_helpers.h"
#include "../../cuda_compat.h"

namespace vllm {

@@ -43,62 +44,150 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
  *rms = s_rms;
}

template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid,
                                          int64_t thread_in_warp,
                                          int64_t reduced_elems) {
  static_assert(WARP_SIZE == 32 || WARP_SIZE == 64);
  if constexpr (WARP_SIZE == 64) {
    if (thread_in_warp + 64 < reduced_elems)
      val[tid] = fmaxf(val[tid], val[tid + 64]);
  }
  if (thread_in_warp + 32 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 32]);
  if (thread_in_warp + 16 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 16]);
  if (thread_in_warp + 8 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 8]);
  if (thread_in_warp + 4 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 4]);
  if (thread_in_warp + 2 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 2]);
  if (thread_in_warp + 1 < reduced_elems)
    val[tid] = fmaxf(val[tid], val[tid + 1]);
  return val[tid];
}

template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    int32_t const hidden_size,
    scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  ;
  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};

    int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
    int32_t const group_size = 0) {
  float block_absmax_val_maybe = 0.0f;
  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
    }

    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  block_absmax_val_maybe =
      BlockReduce(reduceStore)
          .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

  __shared__ float s_token_scale;
  if (threadIdx.x == 0) {
    float scale = 0.0f;
    if (scale_ub) {
      scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      scale = block_absmax_val_maybe;
    }
    // token scale computation
    scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
    s_token_scale = scale;                 // Shared memory store
    all_token_scales[blockIdx.x] = scale;  // Global output store
  }
  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
  __syncthreads();
  if (group_size > 0) {
    __shared__ float s_max_vals[1024];
    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
    int64_t num_groups = hidden_size / group_size;
    int64_t const threads_per_group = blockDim.x / num_groups;
    int64_t const thread_in_group = threadIdx.x % threads_per_group;
    int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
    int64_t const thread_offset = group_offset + thread_in_group;
    int64_t const thread_end =
        min(group_offset + group_size, static_cast<int64_t>(hidden_size));
    for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
      float x = static_cast<float>(input[token_offset + i]);
      if constexpr (has_residual) {
        x += static_cast<float>(residual[token_offset + i]);
      }
      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
    }
    s_max_vals[threadIdx.x] = block_absmax_val_maybe;
    __syncthreads();

  *token_scale = s_token_scale;
    int64_t const warp_size = WARP_SIZE;
    int64_t const num_warps = blockDim.x / warp_size;
    int64_t const warp_id = threadIdx.x / warp_size;
    int64_t const thread_in_warp = threadIdx.x % warp_size;
    int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
    for (auto i = 0; i < groups_per_warp; ++i) {
      int64_t const group_id = i * num_warps + warp_id;
      if (group_id < num_groups) {
        int64_t warp_start = group_id * threads_per_group;
        int64_t const start = warp_start + thread_in_warp;
        int64_t const warp_end = min(warp_start + threads_per_group,
                                     static_cast<int64_t>(hidden_size));
        for (auto j = start; j + warp_size < warp_end; j += warp_size) {
          s_max_vals[start] =
              fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
        }
        warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
                                 min(warp_end - warp_start, warp_size));
      }
    }
    __syncthreads();

    if (thread_in_group == 0 && thread_offset < thread_end) {
      block_absmax_val_maybe = s_max_vals[threadIdx.x];
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      // Global output store
      if constexpr (is_scale_transposed) {
        all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
                         blockIdx.x] = scale;
      } else {
        all_token_scales[blockIdx.x * num_groups +
                         threadIdx.x / threads_per_group] = scale;
      }
    }
    __syncthreads();
  } else {
    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
      float x = static_cast<float>(input[token_offset + i]);
      if constexpr (has_residual) {
        x += static_cast<float>(residual[token_offset + i]);
      }

      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
    }
    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStore;
    block_absmax_val_maybe =
        BlockReduce(reduceStore)
            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

    __shared__ float s_token_scale;
    if (threadIdx.x == 0) {
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      s_token_scale = scale;                 // Shared memory store
      all_token_scales[blockIdx.x] = scale;  // Global output store
    }
    __syncthreads();

    *token_scale = s_token_scale;
  }
}

template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false>
          bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float const scale,
                               float const rms, float* const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr) {
                               scalar_t* __restrict__ residual = nullptr,
                               int32_t const group_size = 0) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  ;

  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
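
Note: both the scalar path above and the vectorized path later derive each group's scale the same way: reduce an absolute maximum over the group, clamp it by scale_ub if given, divide by the quantized type's maximum, and floor it at the minimum scaling factor. A tiny worked check of that formula (standalone and illustrative; qmax = 448 corresponds to FP8 E4M3):

    #include <algorithm>
    #include <cassert>

    float group_scale(float group_absmax, float qmax, float min_scale) {
      return std::max(group_absmax / qmax, min_scale);
    }

    int main() {
      // A group whose largest |value| is 2.24 gets scale 2.24 / 448.
      assert(group_scale(2.24f, 448.0f, 1e-10f) == 2.24f / 448.0f);
      // An all-zero group is clamped to the minimum scaling factor.
      assert(group_scale(0.0f, 448.0f, 1e-10f) == 1e-10f);
      return 0;
    }
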
@@ -109,8 +198,21 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
    // Norm
    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    // Quant
    // If groupwise, is_scale_inverted is true, so we invert the scale here.
    int64_t scale_idx = 0;
    if (group_size > 0) {
      if constexpr (is_scale_transposed) {
        scale_idx = (i / group_size) * gridDim.x + blockIdx.x;
      } else {
        scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
      }
    }
    auto scale_val =
        (group_size > 0
             ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx])
             : *scale);
    output[token_offset + i] =
        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
  }
}

@@ -178,95 +280,191 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
template <typename scalar_t, typename scalar_out_t, bool has_residual = false,
          bool is_scale_transposed = false, int32_t group_size = 0>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    int32_t const hidden_size,
    scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  ;

  // Vectorized input/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};

  const int VEC_SIZE = 4;
  int32_t const num_vec_elems = hidden_size >> 2;
  float block_absmax_val_maybe = 0.0f;

#pragma unroll 4
  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> in = vec_input[i];
    vec4_t<scalar_t> const w = vec_weight[i];
  // Vectorized input/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input = nullptr;
  vec4_t<scalar_t> const* vec_weight = nullptr;
  vec4_t<scalar_t> const* vec_residual = nullptr;

    vec4_t<float> x;
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      x.val[j] = static_cast<float>(in.val[j]);
    }
  if constexpr (group_size > 0) {
    __shared__ float s_max_vals[1024];

    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
    int64_t const num_groups = hidden_size / group_size;
    int64_t const threads_per_group = blockDim.x / num_groups;
    int64_t const thread_in_group = threadIdx.x % threads_per_group;
    int64_t const group_offset =
        threadIdx.x / threads_per_group * (group_size >> 2);
    int64_t const thread_offset = group_offset + thread_in_group;
    int64_t const thread_end = min(group_offset + (group_size >> 2),
                                   static_cast<int64_t>(hidden_size >> 2));
    vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
    vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      vec_residual =
          reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
    }
    int32_t const num_vec_elems = thread_end;

#pragma unroll 4
    for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) {
      vec4_t<scalar_t> in = vec_input[i];
      vec4_t<scalar_t> const w = vec_weight[i];

      vec4_t<float> x;
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] += static_cast<float>(r.val[j]);
        x.val[j] = static_cast<float>(in.val[j]);
      }

      if constexpr (has_residual) {
        vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
        for (int j = 0; j < VEC_SIZE; ++j) {
          x.val[j] += static_cast<float>(r.val[j]);
        }
      }

#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      }
    }

    s_max_vals[threadIdx.x] = block_absmax_val_maybe;
    __syncthreads();

    int64_t const warp_size = WARP_SIZE;
    int64_t const num_warps = blockDim.x / warp_size;
    int64_t const warp_id = threadIdx.x / warp_size;
    int64_t const thread_in_warp = threadIdx.x % warp_size;
    int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps;
    for (auto i = 0; i < groups_per_warp; ++i) {
      int64_t const group_id = i * num_warps + warp_id;
      if (group_id < num_groups) {
        int64_t warp_start = group_id * threads_per_group;
        int64_t const start = warp_start + thread_in_warp;
        int64_t const warp_end = min(warp_start + threads_per_group,
                                     static_cast<int64_t>(hidden_size));
        for (auto j = start; j + warp_size < warp_end; j += warp_size) {
          s_max_vals[start] =
              fmaxf(s_max_vals[start], s_max_vals[j + warp_size]);
        }
        warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp,
                                 min(warp_end - warp_start, warp_size));
      }
    }
    __syncthreads();

    if (thread_in_group == 0 && thread_offset < thread_end) {
      block_absmax_val_maybe = s_max_vals[threadIdx.x];
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      // Global output store
      if constexpr (is_scale_transposed) {
        all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
                         blockIdx.x] = scale;
      } else {
        all_token_scales[blockIdx.x * num_groups +
                         threadIdx.x / threads_per_group] = scale;
      }
    }
    __syncthreads();

  } else {
    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
    vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
    vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
    if constexpr (has_residual) {
      vec_residual =
          reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
    }

    int32_t const num_vec_elems = (hidden_size >> 2);

#pragma unroll 4
    for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
      vec4_t<scalar_t> in = vec_input[i];
      vec4_t<scalar_t> const w = vec_weight[i];

      vec4_t<float> x;
#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      for (int j = 0; j < VEC_SIZE; ++j) {
        x.val[j] = static_cast<float>(in.val[j]);
      }

      if constexpr (has_residual) {
        vec4_t<scalar_t> r = vec_residual[i];
#pragma unroll
        for (int j = 0; j < VEC_SIZE; ++j) {
          x.val[j] += static_cast<float>(r.val[j]);
        }
      }

#pragma unroll
      for (int j = 0; j < VEC_SIZE; ++j) {
        block_absmax_val_maybe =
            fmaxf(block_absmax_val_maybe,
                  fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
      }
    }
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  block_absmax_val_maybe =
      BlockReduce(reduceStore)
          .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStore;
    block_absmax_val_maybe =
        BlockReduce(reduceStore)
            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);

  __shared__ float s_token_scale;
  if (threadIdx.x == 0) {
    float scale = 0.0f;
    if (scale_ub) {
      scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      scale = block_absmax_val_maybe;
    __shared__ float s_token_scale;
    if (threadIdx.x == 0) {
      float scale = 0.0f;
      if (scale_ub) {
        scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        scale = block_absmax_val_maybe;
      }
      // token scale computation
      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
      s_token_scale = scale;                 // shared memory store
      all_token_scales[blockIdx.x] = scale;  // global output store
    }
    // token scale computation
    scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
    s_token_scale = scale;                 // shared memory store
    all_token_scales[blockIdx.x] = scale;  // global output store
  }
    __syncthreads();
  __syncthreads();

    *token_scale = s_token_scale;
  *token_scale = s_token_scale;
  }
}

// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false>
          bool has_residual = false, bool is_scale_transposed = false,
          int32_t group_size = 0>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float const scale,
                               float const rms, float* const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  ;

  // Vectorized input/output/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
@@ -311,10 +509,26 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
    }

    q8x4_t<scalar_out_t> out;

    float scale_val;

    if constexpr (group_size > 0) {
      int64_t const num_groups = hidden_size / group_size;
      int64_t scale_idx = 0;
      if constexpr (is_scale_transposed) {
        scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x;
      } else {
        scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
      }
      scale_val =
          is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx];
    } else {
      scale_val = *scale;
    }
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
          static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
          static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale_val);
    }
    vec_output[i] = out;
  }
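
Note: in the vectorized path above, the loop index i addresses 4-wide vectors, so the element position is i * VEC_SIZE when mapping to a group. A small standalone check of that index math (values illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t const VEC_SIZE = 4, group_size = 128, hidden_size = 512;
      int64_t const num_groups = hidden_size / group_size;  // 4 groups per token
      int64_t const i = 40;     // vector index: covers elements 160..163
      int64_t const block = 3;  // stands in for blockIdx.x (the token)
      assert(i * VEC_SIZE / group_size == 1);                        // group 1
      assert(block * num_groups + i * VEC_SIZE / group_size == 13);  // row-major slot
      return 0;
    }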