dynamic distpatch of fp8 kernels (#14245)

Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-03-11 07:54:56 -07:00
committed by GitHub
parent 08a1a1121d
commit a1c8f3796c
25 changed files with 292 additions and 159 deletions

View File

@@ -11,8 +11,8 @@
namespace vllm {
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
template <typename scalar_t>
template <typename scalar_t, typename fp8_type>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
fp8_type* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
float const min_scaling_factor =
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];
fp8_type* __restrict__ token_output = &out[offset];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion<false>(
token_output[i] = scaled_fp8_conversion<false, fp8_type>(
static_cast<float>(token_input[i]), token_scale);
}
}
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
});
}
@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::segmented_max_reduction<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
input.data_ptr<scalar_t>(),
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
});
});
}
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
input.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
});
});
}