[Perf] Tune scaled_fp8_quant by increasing vectorization (#18844)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -140,6 +140,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
// sum of squares
|
||||
float ss = 0.0f;
|
||||
|
||||
const int VEC_SIZE = 4;
|
||||
int32_t const num_vec_elems = hidden_size >> 2;
|
||||
|
||||
#pragma unroll 4
|
||||
@@ -147,22 +148,23 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
vec4_t<scalar_t> in = vec_input[i];
|
||||
|
||||
vec4_t<float> x;
|
||||
x.x = static_cast<float>(in.x);
|
||||
x.y = static_cast<float>(in.y);
|
||||
x.z = static_cast<float>(in.z);
|
||||
x.w = static_cast<float>(in.w);
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
x.x += static_cast<float>(r.x);
|
||||
x.y += static_cast<float>(r.y);
|
||||
x.z += static_cast<float>(r.z);
|
||||
x.w += static_cast<float>(r.w);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] = static_cast<float>(in.val[j]);
|
||||
}
|
||||
|
||||
ss += x.x * x.x;
|
||||
ss += x.y * x.y;
|
||||
ss += x.z * x.z;
|
||||
ss += x.w * x.w;
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] += static_cast<float>(r.val[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
ss += x.val[j] * x.val[j];
|
||||
}
|
||||
}
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
@@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
const int VEC_SIZE = 4;
|
||||
int32_t const num_vec_elems = hidden_size >> 2;
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
|
||||
@@ -212,26 +215,25 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
|
||||
vec4_t<float> x;
|
||||
x.x = static_cast<float>(in.x);
|
||||
x.y = static_cast<float>(in.y);
|
||||
x.z = static_cast<float>(in.z);
|
||||
x.w = static_cast<float>(in.w);
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
x.x += static_cast<float>(r.x);
|
||||
x.y += static_cast<float>(r.y);
|
||||
x.z += static_cast<float>(r.z);
|
||||
x.w += static_cast<float>(r.w);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] = static_cast<float>(in.val[j]);
|
||||
}
|
||||
|
||||
block_absmax_val_maybe = fmaxf(
|
||||
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
|
||||
block_absmax_val_maybe = fmaxf(
|
||||
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
|
||||
block_absmax_val_maybe = fmaxf(
|
||||
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
|
||||
block_absmax_val_maybe = fmaxf(
|
||||
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] += static_cast<float>(r.val[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
block_absmax_val_maybe =
|
||||
fmaxf(block_absmax_val_maybe,
|
||||
fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
|
||||
}
|
||||
}
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
@@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
|
||||
}
|
||||
|
||||
const int VEC_SIZE = 4;
|
||||
int32_t const num_vec_elems = hidden_size >> 2;
|
||||
|
||||
// TODO(luka/varun) extract into type-agnostic vectorized quant function to
|
||||
@@ -292,33 +295,31 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
|
||||
vec4_t<float> x;
|
||||
x.x = static_cast<float>(in.x);
|
||||
x.y = static_cast<float>(in.y);
|
||||
x.z = static_cast<float>(in.z);
|
||||
x.w = static_cast<float>(in.w);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] = static_cast<float>(in.val[j]);
|
||||
}
|
||||
|
||||
if constexpr (has_residual) {
|
||||
vec4_t<scalar_t> r = vec_residual[i];
|
||||
x.x += static_cast<float>(r.x);
|
||||
x.y += static_cast<float>(r.y);
|
||||
x.z += static_cast<float>(r.z);
|
||||
x.w += static_cast<float>(r.w);
|
||||
// Update residual
|
||||
r.x = static_cast<scalar_t>(x.x);
|
||||
r.y = static_cast<scalar_t>(x.y);
|
||||
r.z = static_cast<scalar_t>(x.z);
|
||||
r.w = static_cast<scalar_t>(x.w);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
x.val[j] += static_cast<float>(r.val[j]);
|
||||
}
|
||||
// Update residual
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
r.val[j] = static_cast<scalar_t>(x.val[j]);
|
||||
}
|
||||
vec_residual[i] = r;
|
||||
}
|
||||
|
||||
q8x4_t<scalar_out_t> out;
|
||||
out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
|
||||
static_cast<scalar_t>(x.x * rms) * w.x, scale);
|
||||
out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
|
||||
static_cast<scalar_t>(x.y * rms) * w.y, scale);
|
||||
out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
|
||||
static_cast<scalar_t>(x.z * rms) * w.z, scale);
|
||||
out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
|
||||
static_cast<scalar_t>(x.w * rms) * w.w, scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
|
||||
static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
|
||||
}
|
||||
vec_output[i] = out;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user