[Perf] Tune scaled_fp8_quant by increasing vectorization (#18844)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-06-03 16:48:25 -04:00
committed by GitHub
parent bdf13965ab
commit e31446b6c8
4 changed files with 118 additions and 113 deletions

View File

@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;
// aligned at 32-byte and 16-byte addresses respectively.
bool const can_vectorize = hidden_size % 16 == 0;
float absmax_val = 0.0f;
if (can_vectorize) {
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
absmax_val = fmaxf(absmax_val, fabsf(x));
}
}
using BlockReduce = cub::BlockReduce<float, 1024>;
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float token_scale;
if (tid == 0) {
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
token_scale = fminf(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / quant_type_max_v<fp8_type>,
min_scaling_factor<fp8_type>::val());
token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
min_scaling_factor<fp8_type>::val());
scale[token_idx] = token_scale;
}
__syncthreads();
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
dim3 const grid(num_tokens);
dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
dim3 const grid(num_tokens);
dim3 const block(block_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
int const block_size = 256;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
dim3 const block(std::min(hidden_size, block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();