671
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
Normal file
671
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
Normal file
@@ -0,0 +1,671 @@
|
||||
#pragma once
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../../attention/attention_dtypes.h"
|
||||
|
||||
namespace vllm {
|
||||
#ifdef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
template <typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
|
||||
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
|
||||
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
|
||||
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
|
||||
// the new HW cvt with something reasonable that doesn't rely on the
|
||||
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
return c10::Float8_e4m3fn(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
|
||||
__hip_fp8_e4m3::__default_interpret),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#else
|
||||
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
|
||||
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
|
||||
__hip_fp8_e4m3_fnuz::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
||||
const float scale) {
|
||||
return x;
|
||||
}
|
||||
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
using fp8_type = __hip_fp8_e4m3;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3;
|
||||
#else
|
||||
using fp8_type = __hip_fp8_e4m3_fnuz;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
|
||||
#endif
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
||||
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8));
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_
|
||||
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a),
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
||||
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, float2>(const float2& a) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
|
||||
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
b.z = __float22bfloat162_rn(a.z);
|
||||
b.w = __float22bfloat162_rn(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low
|
||||
precision domains
|
||||
|
||||
Convention of the scale in API, e.g: FP8_data = Quantization(
|
||||
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
|
||||
scale => HP
|
||||
|
||||
*/
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8) * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
float scale) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y =
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t
|
||||
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2) * scale;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_
|
||||
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
|
||||
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
return {res.x.x, res.x.y, res.y.x, res.y.y};
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
||||
__half_raw res;
|
||||
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
tmp.h2r.x.data *= scale;
|
||||
tmp.h2r.y.data *= scale;
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
tmp.data /= scale;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// halfx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
tmp.h2r.x.data /= scale;
|
||||
tmp.h2r.y.data /= scale;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// half2x2 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// half2x4 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint2 ui2[2];
|
||||
uint4 ui4;
|
||||
} tmp;
|
||||
tmp.ui4 = a;
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
|
||||
const __nv_bfloat162& a, float scale) {
|
||||
union {
|
||||
uint8_t ui8[2];
|
||||
uint16_t ui16;
|
||||
} tmp;
|
||||
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
|
||||
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
|
||||
return tmp.ui16;
|
||||
}
|
||||
|
||||
// bf16x4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// bf16x8 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// floatx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
|
||||
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// floatx4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout convert(const Tin& x) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return vec_conversion<Tout, Tin>(x);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the data type of the key and value cache. The FN is a macro that calls a
|
||||
// function with template<typename scalar_t, typename cache_t,
|
||||
// Fp8KVCacheDataType kv_dt>.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
}
|
||||
|
||||
} // namespace fp8
|
||||
#endif // USE_ROCM
|
||||
} // namespace vllm
|
||||
252
csrc/quantization/w8a8/fp8/common.cu
Normal file
252
csrc/quantization/w8a8/fp8/common.cu
Normal file
@@ -0,0 +1,252 @@
|
||||
#include "common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x; // one token per block
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float inv_scale = 1.0f / (*scale);
|
||||
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
inv_scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction_strided(
|
||||
float* __restrict__ scale, const scalar_t* __restrict__ input,
|
||||
int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
|
||||
__shared__ float cache[256];
|
||||
const int tid = threadIdx.x;
|
||||
int64_t token_idx = blockIdx.x;
|
||||
|
||||
// one block per token. Guard in case gridDim.x > num_tokens.
|
||||
if (token_idx >= num_tokens) {
|
||||
return;
|
||||
}
|
||||
|
||||
const scalar_t* row_ptr = input + token_idx * in_row_stride;
|
||||
|
||||
// each thread scans elements of the row in a strided fashion.
|
||||
float thread_max = 0.0f;
|
||||
for (int e = tid; e < hidden_size; e += blockDim.x) {
|
||||
float v = fabsf(static_cast<float>(row_ptr[e]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
|
||||
cache[tid] = thread_max;
|
||||
__syncthreads();
|
||||
|
||||
// parallel reduction to find row max.
|
||||
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
|
||||
if (tid < offset) {
|
||||
cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// thread 0 updates global scale (per-tensor) atomically.
|
||||
if (tid == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel_strided_dynamic(
|
||||
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
|
||||
int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const scalar_t* token_in = input + token_idx * in_row_stride;
|
||||
fp8_type* token_out = out + token_idx * out_row_stride;
|
||||
|
||||
const float reciprocal_scale = 1.0f / (*scale);
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
|
||||
reciprocal_scale);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
|
||||
int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Use int64 to avoid overflowing an int32 when calculating this offset
|
||||
int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
|
||||
int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
|
||||
const scalar_t* token_in = input + in_offset;
|
||||
fp8_type* token_out = out + out_offset;
|
||||
|
||||
// 1) per-token absmax
|
||||
float absmax_val = 0.f;
|
||||
vectorize_read_with_alignment<16>(
|
||||
token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
|
||||
absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
const float block_max =
|
||||
BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
|
||||
token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
|
||||
min_scaling_factor<fp8_type>::val());
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 2) quantize
|
||||
vectorize_with_alignment<16>(
|
||||
token_in, token_out, hidden_size, tid, blockDim.x,
|
||||
[=] __device__(fp8_type & dst, const scalar_t& src) {
|
||||
dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
|
||||
token_scale);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(block_size);
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// scale tensor should be initialised to <=0 before reduction
|
||||
AT_CUDA_CHECK(
|
||||
cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size, in_row_stride,
|
||||
static_cast<int64_t>(num_tokens));
|
||||
|
||||
vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), hidden_size, in_row_stride,
|
||||
out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"last dimension of input must be contiguous");
|
||||
TORCH_CHECK(out.stride(-1) == 1,
|
||||
"last dimension of output must be contiguous");
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
const int block_size = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, block_size));
|
||||
|
||||
const int64_t in_row_stride = input.stride(-2);
|
||||
const int64_t out_row_stride = out.stride(-2);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
|
||||
scalar_t, fp8_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size, in_row_stride, out_row_stride);
|
||||
});
|
||||
});
|
||||
}
|
||||
58
csrc/quantization/w8a8/fp8/common.cuh
Normal file
58
csrc/quantization/w8a8/fp8/common.cuh
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include "quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include "amd/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
// Determines the preferred FP8 type for the current platform.
|
||||
// Note that for CUDA this just returns true,
|
||||
// but on ROCm it will check device props.
|
||||
static bool is_fp8_ocp() {
|
||||
#ifndef USE_ROCM
|
||||
return true;
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
size_t substring = device_arch.find("gfx94");
|
||||
return substring == std::string::npos;
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
float old;
|
||||
old = (value >= 0)
|
||||
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
|
||||
: __uint_as_float(
|
||||
atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
|
||||
return old;
|
||||
}
|
||||
|
||||
template <bool is_scale_inverted, typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
x = val * scale;
|
||||
} else {
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r =
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
573
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
Normal file
573
csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh
Normal file
@@ -0,0 +1,573 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../../../attention/attention_dtypes.h"
|
||||
#include <assert.h>
|
||||
#include <float.h>
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace vllm {
|
||||
#ifndef USE_ROCM
|
||||
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout
|
||||
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
|
||||
tmp.u32[1] =
|
||||
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
|
||||
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
|
||||
res.y =
|
||||
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float
|
||||
vec_conversion<float, uint8_t>(const uint8_t &a,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
|
||||
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
|
||||
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
|
||||
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
|
||||
const float &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
|
||||
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
|
||||
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
|
||||
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
|
||||
const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
|
||||
const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
|
||||
const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low
|
||||
precision domains Convention of the scale in API, e.g: FP8_data =
|
||||
Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
|
||||
Dequant(FP8) * scale => HP
|
||||
*/
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(
|
||||
const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
return float_to_half(half_to_float(tmp.x) * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
|
||||
tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
|
||||
tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale, fp8_type);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp * scale);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
|
||||
fp8_type);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
|
||||
scale, fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
|
||||
fp8_type);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
scale, fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
|
||||
const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
|
||||
uint16_t tmp = res.x;
|
||||
|
||||
// half -> float
|
||||
return half_to_float(tmp) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
|
||||
fp8_type);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
|
||||
const uint2& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
|
||||
const uint16_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
__NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
|
||||
const float& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
|
||||
const uint32_t& a, const float scale,
|
||||
const __nv_fp8_interpretation_t fp8_type) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout convert(const Tin& x) {
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
|
||||
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
#ifdef ENABLE_FP8
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
|
||||
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
|
||||
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the data type of the key and value cache. The FN is a macro that calls a
|
||||
// function with template<typename scalar_t, typename cache_t,
|
||||
// Fp8KVCacheDataType kv_dt>.
|
||||
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
|
||||
if (KV_DTYPE == "auto") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_e5m2") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
}
|
||||
|
||||
} // namespace fp8
|
||||
#endif // not USE_ROCM
|
||||
} // namespace vllm
|
||||
Reference in New Issue
Block a user