[ROCm][Quantization][Kernel] Using HIP FP8 header (#12593)

This commit is contained in:
Gregory Shtrasberg
2025-02-25 03:39:59 -05:00
committed by GitHub
parent 2f42a4888c
commit aabeb2688f
6 changed files with 267 additions and 634 deletions

View File

@@ -1,13 +1,11 @@
#pragma once
#include "hip_float8.h"
#include <hip/hip_fp8.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
#include "../../../attention/attention_dtypes.h"
namespace vllm {
#ifdef USE_ROCM
@@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return x;
}
#if HIP_FP8_TYPE_FNUZ
using fp8_type = __hip_fp8_e4m3_fnuz;
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
#elif HIP_FP8_TYPE_OCP
using fp8_type = __hip_fp8_e4m3;
using fp8x2_type = __hip_fp8x2_e4m3;
#endif
// fp8 -> half
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8);
return res.x;
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r.x.data = f2[0];
tmp.h2r.y.data = f2[1];
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
return tmp.ui32;
#else
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
return tmp.u32;
#endif
}
// fp8x4 -> half2x2
@@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f);
fp8_type f8;
f8.__x = a;
return __float2bfloat16(static_cast<float>(f8));
}
using __nv_bfloat162 = __hip_bfloat162;
@@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8);
fp8_type f8;
f8.__x = a;
return static_cast<float>(f8);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0];
res.y = f2[1];
return res;
#else
float2 res;
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
return res;
#endif
fp8x2_type f8x2;
f8x2.__x = a;
return static_cast<float2>(f8x2);
}
// fp8x4 -> float4
@@ -169,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
@@ -189,33 +178,36 @@ __inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp;
tmp.x = a;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
hip_fp8 f8{static_cast<float>(tmp.data)};
return f8.data;
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
union {
uint32_t ui32;
__half2_raw h2r;
} tmp;
tmp.ui32 = a;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
hip_fp8 res{__bfloat162float(a)};
return res.data;
return __hip_cvt_float_to_fp8(__bfloat162float(a),
fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
hip_fp8 f8(a);
return f8.data;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// float2 -> half2
@@ -307,90 +299,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
*/
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8) * scale;
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r.x.data = f2[0] * scale;
tmp.h2r.y.data = f2[1] * scale;
return tmp.ui32;
#else
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
tmp.u16[0] =
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
static_cast<uint8_t>(a >> 8U), scale);
return tmp.u32;
#endif
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
return tmp.u64x2;
}
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
const float scale) {
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f * scale);
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a;
return __float2bfloat16(static_cast<float>(f8) * scale);
}
using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
const float scale) {
float scale) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.y =
@@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
const uint32_t& a, const float scale) {
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
@@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
@@ -427,29 +351,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, const float scale) {
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8) * scale;
const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a;
return static_cast<float>(f8) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0] * scale;
res.y = f2[1] * scale;
return res;
#else
float2 res;
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
scale);
return res;
#endif
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
fp8x2_type f8x2;
f8x2.__x = a;
return static_cast<float2>(f8x2) * scale;
}
// fp8x4 -> float4
@@ -462,10 +376,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
return {res.x.x, res.x.y, res.y.x, res.y.y};
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
@@ -477,44 +399,184 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
return res;
}
/* Quantize(HP / scale) => FP8 */
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
__half_raw res;
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
return res.x;
}
// TODO(Hai): vectorized to add
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
__half2_raw h2r =
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
tmp.h2r.x.data *= scale;
tmp.h2r.y.data *= scale;
return tmp.ui32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
return tmp.u64x2;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
__half_raw tmp;
tmp.x = a;
tmp.data /= scale;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
return f8.data;
// halfx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
union {
uint32_t ui32;
__half2_raw h2r;
} tmp;
tmp.ui32 = a;
tmp.h2r.x.data /= scale;
tmp.h2r.y.data /= scale;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// half2x2 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
return tmp.ui32;
}
// half2x4 -> fp8x8
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) {
union {
uint2 ui2[2];
uint4 ui4;
} tmp;
tmp.ui4 = a;
uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
return res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, const float scale) {
hip_fp8 res{__bfloat162float(a) / scale};
return res.data;
const __nv_bfloat16& a, float scale) {
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// bf16x2 -> fp8x2
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
return tmp.ui16;
}
// bf16x4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
return tmp.ui32;
}
// bf16x8 -> fp8x8
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
return res;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
hip_fp8 f8(a / scale);
return f8.data;
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// fp8x4 -> float4
// floatx2 -> fp8x2
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// floatx4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
return tmp.ui32;
}
#endif // ENABLE_FP8