dynamic distpatch of fp8 kernels (#14245)

Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-03-11 07:54:56 -07:00
committed by GitHub
parent 08a1a1121d
commit a1c8f3796c
25 changed files with 292 additions and 159 deletions

View File

@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#endif
}
static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<FP8_TYPE>(r);
template <typename fp8_type>
static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
return static_cast<fp8_type>(r);
}
template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
@@ -54,15 +56,16 @@ struct ScaledQuant<
};
template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
quant_type_t, is_scale_inverted,
typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
struct ScaledQuant<quant_type_t, is_scale_inverted,
typename std::enable_if_t<
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
static __device__ __forceinline__ quant_type_t quant_fn(float const x,
float const scale) {
if constexpr (is_scale_inverted) {
return float_to_fp8(x * scale);
return float_to_fp8<quant_type_t>(x * scale);
} else {
return float_to_fp8(x / scale);
return float_to_fp8<quant_type_t>(x / scale);
}
}
};