dynamic distpatch of fp8 kernels (#14245)

Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jeff Daily
2025-03-11 07:54:56 -07:00
committed by GitHub
parent 08a1a1121d
commit a1c8f3796c
25 changed files with 292 additions and 159 deletions

View File

@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

View File

@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#endif
}
static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<FP8_TYPE>(r);
template <typename fp8_type>
static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
return static_cast<fp8_type>(r);
}
template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
@@ -54,15 +56,16 @@ struct ScaledQuant<
};
template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
quant_type_t, is_scale_inverted,
typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
struct ScaledQuant<quant_type_t, is_scale_inverted,
typename std::enable_if_t<
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
static __device__ __forceinline__ quant_type_t quant_fn(float const x,
float const scale) {
if constexpr (is_scale_inverted) {
return float_to_fp8(x * scale);
return float_to_fp8<quant_type_t>(x * scale);
} else {
return float_to_fp8(x / scale);
return float_to_fp8<quant_type_t>(x / scale);
}
}
};