dynamic distpatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
|
||||
float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
return static_cast<FP8_TYPE>(r);
|
||||
template <typename fp8_type>
|
||||
static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
|
||||
float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
|
||||
return static_cast<fp8_type>(r);
|
||||
}
|
||||
|
||||
template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
|
||||
@@ -54,15 +56,16 @@ struct ScaledQuant<
|
||||
};
|
||||
|
||||
template <typename quant_type_t, bool is_scale_inverted>
|
||||
struct ScaledQuant<
|
||||
quant_type_t, is_scale_inverted,
|
||||
typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
|
||||
struct ScaledQuant<quant_type_t, is_scale_inverted,
|
||||
typename std::enable_if_t<
|
||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
|
||||
static __device__ __forceinline__ quant_type_t quant_fn(float const x,
|
||||
float const scale) {
|
||||
if constexpr (is_scale_inverted) {
|
||||
return float_to_fp8(x * scale);
|
||||
return float_to_fp8<quant_type_t>(x * scale);
|
||||
} else {
|
||||
return float_to_fp8(x / scale);
|
||||
return float_to_fp8<quant_type_t>(x / scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user