dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
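This commit replaces the compile-time FP8_TYPE / FP8_E4M3_MAX constants with runtime dispatch on the output tensor's fp8 dtype, so a single build serves both the OCP float8_e4m3fn format and ROCm's float8_e4m3fnuz. Each kernel gains an fp8_type template parameter, and the host entry points pick the instantiation through a new VLLM_DISPATCH_FP8_TYPES macro nested inside the existing floating-type dispatch. The macro's definition is not part of this diff; the following is only a sketch of the shape such a macro can take, inferred from the call sites below (the body is an assumption, not the actual implementation):

```cpp
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>

// Illustrative sketch: switch on a runtime ScalarType and bind the matching
// C++ type to the name fp8_t inside the lambda body, mirroring PyTorch's
// AT_DISPATCH_* pattern.
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...)             \
  [&] {                                                      \
    switch (TYPE) {                                          \
      case c10::ScalarType::Float8_e4m3fn: {                 \
        using fp8_t = c10::Float8_e4m3fn; /* OCP fp8 */      \
        return __VA_ARGS__();                                \
      }                                                      \
      case c10::ScalarType::Float8_e4m3fnuz: {               \
        using fp8_t = c10::Float8_e4m3fnuz; /* gfx94x */     \
        return __VA_ARGS__();                                \
      }                                                      \
      default:                                               \
        TORCH_CHECK(false, NAME, ": unsupported fp8 dtype"); \
    }                                                        \
  }()
```

Nesting this inside VLLM_DISPATCH_FLOATING_TYPES gives every (scalar_t, fp8_t) pair its own kernel instantiation, selected at call time from the input and output dtypes.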
@@ -13,6 +13,28 @@ namespace vllm {
 namespace fp8 {
 #ifdef ENABLE_FP8
 
+// Use hardware cvt instruction for fp8 on rocm
+template <typename fp8_type>
+__device__ __forceinline__ fp8_type cvt_c10(float const r) {
+  return {};
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
+  return c10::Float8_e4m3fn(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
+                             __hip_fp8_e4m3::__default_interpret),
+      c10::Float8_e4m3fn::from_bits());
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
+                             __hip_fp8_e4m3_fnuz::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
+}
+
 template <typename Tout, typename Tin>
 __inline__ __device__ Tout vec_conversion(const Tin& x) {
   return x;
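The unspecialized cvt_c10 above is a compile-time placeholder that returns a zero-initialized value; only the two specializations are meant to be reached, and each routes to the HIP __hip_cvt_float_to_fp8 intrinsic with the saturation and interpretation constants of the matching hardware type. A hypothetical device-side caller (quantize_row is illustrative, not part of the commit):

```cpp
// Hypothetical kernel templated on the c10 fp8 type; it picks up the right
// hardware cvt path through the cvt_c10 specializations above.
template <typename fp8_type>
__global__ void quantize_row(fp8_type* __restrict__ out,
                             float const* __restrict__ in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = vllm::fp8::cvt_c10<fp8_type>(in[i]);
  }
}
```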
@@ -11,8 +11,8 @@
 
 namespace vllm {
 
-template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
       out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
+    fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+  float const min_scaling_factor =
+      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
 
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   // Use int64 to avoid overflowing an int32 when calculating this offset
   int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
   scalar_t const* __restrict__ token_input = &input[offset];
-  FP8_TYPE* __restrict__ token_output = &out[offset];
+  fp8_type* __restrict__ token_output = &out[offset];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
+                      min_scaling_factor);
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
         token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false>(
+      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
           static_cast<float>(token_input[i]), token_scale);
     }
   }
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
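On the host side the fp8 flavor is no longer baked in at compile time: the same compiled op handles both flavors, keyed off out.scalar_type(). A usage sketch (the demo function and tensor setup are illustrative, not from the commit):

```cpp
#include <torch/torch.h>

// Illustrative: which kernel instantiation runs now follows out's dtype.
void demo_static_quant() {
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({16, 128}, opts);
  torch::Tensor scale = torch::ones({1}, opts.dtype(torch::kFloat32));
  // An OCP float8_e4m3fn out tensor selects the Float8_e4m3fn
  // instantiation; a float8_e4m3fnuz tensor would select the fnuz one.
  torch::Tensor out = torch::empty(
      input.sizes(), opts.dtype(c10::ScalarType::Float8_e4m3fn));
  static_scaled_fp8_quant(out, input, scale);
}
```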
@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
-            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::segmented_max_reduction<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
+                                               input.data_ptr<scalar_t>(),
+                                               num_elems);
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
-        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(),
-                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                hidden_size);
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                      input.data_ptr<scalar_t>(),
+                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                           : nullptr,
+                      hidden_size);
+            });
       });
 }
 
@@ -7,18 +7,52 @@
 
 #ifndef USE_ROCM
   #include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
-    std::numeric_limits<FP8_TYPE>::max();
   #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
 #else
+  #include <ATen/hip/HIPContext.h>
+  #include <c10/util/Float8_e4m3fn.h>
   #include <c10/util/Float8_e4m3fnuz.h>
   #include "amd/quant_utils.cuh"
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-// Using the default max value from pytorch (240.0) will cause accuracy
-// issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
 // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
   #define MAYBE_HOST_DEVICE
 #endif
-constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
+
+// Determines the preferred FP8 type for the current platform.
+// Note that for CUDA this just returns true,
+// but on ROCm it will check device props.
+static bool is_fp8_ocp() {
+#ifndef USE_ROCM
+  return true;
+#else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string device_arch = dprops->gcnArchName;
+  size_t substring = device_arch.find("gfx94");
+  return substring == std::string::npos;
+#endif
+}
+
+template <typename T>
+struct fp8_e4m3_adjusted_max;
+
+template <>
+struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
+  static constexpr c10::Float8_e4m3fn val() {
+    return std::numeric_limits<c10::Float8_e4m3fn>::max();
+  }
+};
+
+// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
+// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
+template <>
+struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
+  static constexpr c10::Float8_e4m3fnuz val() {
+    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
+  }
+};
+
+template <typename T>
+MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
+    fp8_e4m3_adjusted_max<T>::val();
 
 namespace vllm {
 
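Two platform probes replace the old constants: is_fp8_ocp() distinguishes OCP-fp8 devices from gfx94x at runtime, while fp8_e4m3_adjusted_max_v<T> exposes the per-type clamp bound at compile time (the e4m3fn numeric max, 448.0, versus 224.0, bit pattern 0x7E, for e4m3fnuz). A hypothetical helper showing how calling code could use the runtime probe when allocating outputs (preferred_fp8_dtype is not part of the commit):

```cpp
// Hypothetical: allocate outputs in the fp8 flavor the current device
// prefers -- OCP e4m3fn everywhere except gfx94x, which uses fnuz.
inline c10::ScalarType preferred_fp8_dtype() {
  return is_fp8_ocp() ? c10::ScalarType::Float8_e4m3fn
                      : c10::ScalarType::Float8_e4m3fnuz;
}
```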
@@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-template <bool is_scale_inverted>
-__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+template <bool is_scale_inverted, typename fp8_type>
+__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                           float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
@@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
     x = val / scale;
   }
 
-  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
+                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
 #ifndef USE_ROCM
-  return static_cast<c10::Float8_e4m3fn>(r);
+  return static_cast<fp8_type>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
-  return c10::Float8_e4m3fnuz(
-      __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
-                             fp8::fp8_type::__default_interpret),
-      c10::Float8_e4m3fnuz::from_bits());
+  return fp8::cvt_c10<fp8_type>(r);
 #endif
 }
 
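Since the clamp in scaled_fp8_conversion now uses the adjusted max of the requested type, saturation follows the template argument rather than a global constant. An illustrative device-side expectation (saturate_demo is not from the commit):

```cpp
// Illustrative: out-of-range inputs clamp to the type's adjusted max
// before conversion, so no inf/NaN can reach the fp8 cvt.
__global__ void saturate_demo(c10::Float8_e4m3fnuz* out) {
  // 1e6f / 1.0f is far above 224.0f, so this stores the fnuz max (0x7E).
  out[0] = vllm::scaled_fp8_conversion<false, c10::Float8_e4m3fnuz>(1e6f,
                                                                    1.0f);
}
```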
@@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
 // So to get the right answer, *scale needs to be initialized to
 // a value <= 0.0 and we need to wait for all thread blocks to
 // finish before consuming *scale.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void segmented_max_reduction(float* __restrict__ scale,
                                         const scalar_t* __restrict__ input,
                                         int64_t num_elems) {
@@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
+    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
   }
 }
 
@@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   return absmax_val;
 }
 
-template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
+__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
-  using float8x4_t = q8x4_t<FP8_TYPE>;
+  using float8x4_t = q8x4_t<fp8_type>;
   // Vectorized input/output to better utilize memory bandwidth.
   auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
   auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
@@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;
 
-    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
         static_cast<float>(in_vec.x), scale);
-    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.y), scale);
-    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.z), scale);
-    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.w), scale);
     vectorized_out[i] = out_vec;
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
   for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
   }
 }
 
 } // namespace vllm
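The vectorized path keeps its 4-wide stores; only the element type became generic. The q8x4_t definition lives elsewhere in the tree, so the following is just a plausible sketch of its shape, not the actual declaration:

```cpp
// Plausible sketch: four 1-byte fp8 values packed into one 4-byte-aligned
// struct, so a single 32-bit store writes all four lanes.
template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(sizeof(quant_type_t) == 1, "expected a 1-byte fp8 type");
  quant_type_t x, y, z, w;
};
```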