[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin
2024-05-22 03:18:41 -04:00
committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

View File

@@ -10,9 +10,9 @@ namespace vllm {
#ifndef USE_ROCM
namespace fp8 {
#ifdef ENABLE_FP8
#ifdef ENABLE_FP8
#if 0 // Disable the following code to reduce the binary size.
#if 0 // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
@@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
#endif
}
// float -> fp8
@@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
from_float(b, a);
return b;
}
#endif
#endif
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains Convention of the scale in API, e.g: FP8_data =
@@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
const uint8_t &a, const float scale,
const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
return float_to_half(half_to_float(tmp.x) * scale);
@@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
const uint16_t &a, const float scale,
const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint16_t u16[2];
@@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
const uint32_t &a, const float scale,
const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint2 u32x2;
@@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
union {
uint4 u64x2;
@@ -348,7 +348,7 @@ scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
const uint8_t &a, const float scale,
const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
@@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
const uint16_t &a, const float scale,
const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
@@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
const uint32_t &a, const float scale,
const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
@@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
const uint2 &a, const float scale,
const uint2& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
@@ -404,9 +404,8 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t &a, const float scale,
const uint8_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
uint16_t tmp = res.x;
@@ -418,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
const uint16_t &a, const float scale,
const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
// fp8x2 -> half2
uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
@@ -429,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
const uint32_t &a, const float scale,
const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
@@ -441,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
const uint2 &a, const float scale,
const uint2& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
@@ -457,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
const uint16_t &a, const float scale,
const uint16_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
@@ -467,21 +466,21 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const float scale,
const __nv_bfloat16& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
#else
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
__NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
#endif
}
// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
const float &a, const float scale,
const float& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
@@ -491,78 +490,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
const uint32_t &a, const float scale,
const uint32_t& a, const float scale,
const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
#endif // ENABLE_FP8
#endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
#if 0 // Disable the following code to reduce the binary size.
__inline__ __device__ Tout convert(const Tin& x) {
#if 0 // Disable the following code to reduce the binary size.
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
}
#endif
#endif
assert(false);
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
#ifdef ENABLE_FP8
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
}
#endif
#endif
assert(false);
}
// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // namespace fp8
#endif // not USE_ROCM
} // namespace vllm
} // namespace fp8
#endif // not USE_ROCM
} // namespace vllm