dynamic distpatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
|
||||
torch::Tensor& scales, // [num_tokens]
|
||||
double const var_epsilon, // Variance epsilon used in norm calculation
|
||||
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
|
||||
|
||||
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
|
||||
float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
return static_cast<FP8_TYPE>(r);
|
||||
template <typename fp8_type>
|
||||
static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
|
||||
float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
|
||||
return static_cast<fp8_type>(r);
|
||||
}
|
||||
|
||||
template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
|
||||
@@ -54,15 +56,16 @@ struct ScaledQuant<
|
||||
};
|
||||
|
||||
template <typename quant_type_t, bool is_scale_inverted>
|
||||
struct ScaledQuant<
|
||||
quant_type_t, is_scale_inverted,
|
||||
typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
|
||||
struct ScaledQuant<quant_type_t, is_scale_inverted,
|
||||
typename std::enable_if_t<
|
||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
|
||||
static __device__ __forceinline__ quant_type_t quant_fn(float const x,
|
||||
float const scale) {
|
||||
if constexpr (is_scale_inverted) {
|
||||
return float_to_fp8(x * scale);
|
||||
return float_to_fp8<quant_type_t>(x * scale);
|
||||
} else {
|
||||
return float_to_fp8(x / scale);
|
||||
return float_to_fp8<quant_type_t>(x / scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user