[ROCm][Quantization][Kernel] Using HIP FP8 header (#12593)
This commit is contained in:
committed by
GitHub
parent
2f42a4888c
commit
aabeb2688f
@@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
|
||||
std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include "amd/hip_float8.h"
|
||||
#include "amd/quant_utils.cuh"
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||
@@ -47,8 +47,10 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
|
||||
fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user