[Perf] Tune scaled_fp8_quant by increasing vectorization (#18844)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ token_output = &out[offset];
|
||||
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
bool const can_vectorize = hidden_size % 4 == 0;
|
||||
// aligned at 32-byte and 16-byte addresses respectively.
|
||||
bool const can_vectorize = hidden_size % 16 == 0;
|
||||
|
||||
float absmax_val = 0.0f;
|
||||
if (can_vectorize) {
|
||||
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
float const x = static_cast<float>(token_input[i]);
|
||||
absmax_val = max(absmax_val, fabs(x));
|
||||
absmax_val = fmaxf(absmax_val, fabsf(x));
|
||||
}
|
||||
}
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||
float const block_absmax_val_maybe =
|
||||
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
if (scale_ub) {
|
||||
token_scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
token_scale = fminf(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = max(token_scale / quant_type_max_v<fp8_type>,
|
||||
min_scaling_factor<fp8_type>::val());
|
||||
token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
|
||||
min_scaling_factor<fp8_type>::val());
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
int const block_size = 256;
|
||||
int const num_tokens = input.numel() / input.size(-1);
|
||||
int const num_elems = input.numel();
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(block_size);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
int const block_size = 256;
|
||||
int const num_tokens = input.numel() / input.size(-1);
|
||||
int const num_elems = input.numel();
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(block_size);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
int const block_size = 256;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
dim3 const block(std::min(hidden_size, block_size));
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
}
|
||||
|
||||
float r =
|
||||
fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
@@ -65,7 +65,7 @@ template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
__shared__ float cache[1024];
|
||||
__shared__ float cache[256];
|
||||
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// First store maximum for all values processed by
|
||||
@@ -73,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
scalar_t tmp = 0.0;
|
||||
while (i < num_elems) {
|
||||
float x = static_cast<float>(input[i]);
|
||||
tmp = max(tmp, fabs(x));
|
||||
tmp = fmaxf(tmp, fabsf(x));
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
cache[threadIdx.x] = tmp;
|
||||
@@ -100,25 +100,27 @@ template <typename scalar_t>
|
||||
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
int64_t const num_elems, int const tid,
|
||||
int const step) {
|
||||
constexpr size_t VEC_SIZE = 16;
|
||||
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vectorized_in =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
|
||||
|
||||
int64_t const num_vec_elems = num_elems >> 2;
|
||||
// num_elems / VEC_SIZE (which is 16)
|
||||
int64_t const num_vec_elems = num_elems >> 4;
|
||||
float absmax_val = 0.0f;
|
||||
|
||||
#pragma unroll 4
|
||||
#pragma unroll
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
absmax_val = max(absmax_val, fabs(in_vec.x));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.y));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.z));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.w));
|
||||
scalarxN_t in_vec = vectorized_in[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
absmax_val = max(absmax_val, fabs(input[i]));
|
||||
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
|
||||
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(input[i]));
|
||||
}
|
||||
|
||||
return absmax_val;
|
||||
@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
using float8x4_t = q8x4_t<fp8_type>;
|
||||
constexpr size_t VEC_SIZE = 16;
|
||||
using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
|
||||
using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
|
||||
|
||||
int64_t const num_vec_elems = num_elems >> 2;
|
||||
// num_elems / VEC_SIZE (which is 16)
|
||||
int64_t const num_vec_elems = num_elems >> 4;
|
||||
|
||||
#pragma unroll 4
|
||||
#pragma unroll
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
scalarxN_t in_vec = vectorized_in[i];
|
||||
float8xN_t out_vec;
|
||||
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.x), scale);
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.y), scale);
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.z), scale);
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.w), scale);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VEC_SIZE; ++j) {
|
||||
out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.val[j]), scale);
|
||||
}
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
// Handle the remaining elements if num_elems is not divisible by VEC_SIZE
|
||||
for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user