diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 1aad6330c..7838f211c 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -5,7 +5,9 @@ #include -#ifdef USE_ROCM +#ifndef USE_ROCM + #include "nvidia/quant_utils.cuh" +#else #include "amd/quant_utils.cuh" #endif @@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, float r = fmaxf(-quant_type_max_v, fminf(x, quant_type_max_v)); #ifndef USE_ROCM - return static_cast(r); + // Use hardware cvt instruction for fp8 on nvidia + // Currently only support fp8_type = c10::Float8_e4m3fn + return fp8::vec_conversion(r); #else // Use hardware cvt instruction for fp8 on rocm return fp8::cvt_c10(r); diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index f8cd1dcba..5b9c2df84 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -12,13 +12,26 @@ namespace vllm { namespace fp8 { #ifdef ENABLE_FP8 - #if 0 // Disable the following code to reduce the binary size. template -__inline__ __device__ Tout -vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { +__inline__ __device__ Tout vec_conversion( + const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) { return x; } +// float -> c10::Float8_e4m3fn +template <> +__inline__ __device__ c10::Float8_e4m3fn +vec_conversion( + const float& a, const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return static_cast(a); + #else + return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type), + c10::Float8_e4m3fn::from_bits()); + #endif +} + + #if 0 // Disable the following code to reduce the binary size. // fp8 -> half template <> __inline__ __device__ uint16_t vec_conversion(