[Perf] Tune scaled_fp8_quant by increasing vectorization (#18844)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
}
|
||||
|
||||
float r =
|
||||
fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
@@ -65,7 +65,7 @@ template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
__shared__ float cache[1024];
|
||||
__shared__ float cache[256];
|
||||
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// First store maximum for all values processes by
|
||||
@@ -73,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
scalar_t tmp = 0.0;
|
||||
while (i < num_elems) {
|
||||
float x = static_cast<float>(input[i]);
|
||||
tmp = max(tmp, fabs(x));
|
||||
tmp = fmaxf(tmp, fabsf(x));
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
cache[threadIdx.x] = tmp;
|
||||
@@ -100,25 +100,27 @@ template <typename scalar_t>
|
||||
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
int64_t const num_elems, int const tid,
|
||||
int const step) {
|
||||
constexpr size_t VEC_SIZE = 16;
|
||||
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vectorized_in =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
|
||||
|
||||
int64_t const num_vec_elems = num_elems >> 2;
|
||||
// num_elems / VEC_SIZE (which is 16)
|
||||
int64_t const num_vec_elems = num_elems >> 4;
|
||||
float absmax_val = 0.0f;
|
||||
|
||||
#pragma unroll 4
|
||||
#pragma unroll
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
absmax_val = max(absmax_val, fabs(in_vec.x));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.y));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.z));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.w));
|
||||
scalarxN_t in_vec = vectorized_in[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
absmax_val = max(absmax_val, fabs(input[i]));
|
||||
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
|
||||
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(input[i]));
|
||||
}
|
||||
|
||||
return absmax_val;
|
||||
@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
using float8x4_t = q8x4_t<fp8_type>;
|
||||
constexpr size_t VEC_SIZE = 16;
|
||||
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
|
||||
using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
|
||||
|
||||
int64_t const num_vec_elems = num_elems >> 2;
|
||||
// num_elems / VEC_SIZE (which is 16)
|
||||
int64_t const num_vec_elems = num_elems >> 4;
|
||||
|
||||
#pragma unroll 4
|
||||
#pragma unroll
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
scalarxN_t in_vec = vectorized_in[i];
|
||||
float8xN_t out_vec;
|
||||
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.x), scale);
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.y), scale);
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.z), scale);
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.w), scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.val[j]), scale);
|
||||
}
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
|
||||
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user