[torch.compile] Dynamic fp8 + rms_norm fusion (#10906)

Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Author: Luka Govedič
Date: 2024-12-12 22:19:23 -05:00
Parent: 78ed8f57d8
Commit: 30870b4f66
20 changed files with 1735 additions and 251 deletions


@@ -1,6 +1,9 @@
 #pragma once
 
+#include "quantization/vectorization.cuh"
+
 #include <cmath>
+#include <c10/core/ScalarType.h>
 
 #ifndef USE_ROCM
   #include <c10/util/Float8_e4m3fn.h>
@@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
 // issue when running dynamic quantization. Here use 224.0f for rocm.
 constexpr auto FP8_E4M3_MAX = 224.0f;
 #endif
 
+constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
 
 namespace vllm {
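The new kFp8Type constant maps the compile-time FP8_TYPE to its runtime c10::ScalarType, so host-side wrappers can validate tensor dtypes against it. A minimal sketch of such a check (check_fp8_output is a hypothetical helper, not part of this diff; it assumes torch/all.h is available):

#include <torch/all.h>

// Hypothetical illustration: compare a tensor's runtime dtype against the
// compile-time FP8 type (Float8_e4m3fn on CUDA, Float8_e4m3fnuz on ROCm).
void check_fp8_output(torch::Tensor const& out) {
  TORCH_CHECK(out.dtype() == kFp8Type,
              "expected an fp8 output tensor, got ", out.dtype());
}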
@@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   }
 }
 
-template <typename scalar_t>
-struct __align__(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
-};
-
-typedef struct __align__(4) {
-  FP8_TYPE x;
-  FP8_TYPE y;
-  FP8_TYPE z;
-  FP8_TYPE w;
-} float8x4_t;
-
 template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,
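The vec4_t and float8x4_t definitions above are removed because the vectorized types now live in quantization/vectorization.cuh, with the FP8-only float8x4_t generalized into a q8x4_t template (see the using declaration in the next hunk). A minimal sketch of what that header plausibly contains; the exact contents are an assumption, since the new file is not shown in this excerpt:

#pragma once

namespace vllm {

// 4-wide vector of the unquantized element type, 8-byte aligned for
// vectorized loads.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

// 4-wide vector of any 1-byte quantized type (fp8, int8), replacing the
// FP8-only float8x4_t.
template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(sizeof(quant_type_t) == 1);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};

}  // namespace vllm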
@@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
+  using float8x4_t = q8x4_t<FP8_TYPE>;
+
   // Vectorized input/output to better utilize memory bandwidth.
-  vec4_t<scalar_t> const* vectorized_in =
-      reinterpret_cast<vec4_t<scalar_t> const*>(input);
-  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
 
   int64_t const num_vec_elems = num_elems >> 2;
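For context, the vectorized pointers set up above feed a loop in which each thread converts four elements per iteration. An illustrative reconstruction of that loop (the body below is assumed, not shown in this excerpt; scaled_fp8_conversion is taken to clamp the scaled value to FP8_E4M3_MAX and cast it to FP8_TYPE):

for (int64_t i = tid; i < num_vec_elems; i += step) {
  vec4_t<scalar_t> in_vec = vectorized_in[i];
  float8x4_t out_vec;
  // Convert each lane: divide (or multiply, if the scale is pre-inverted)
  // by the scale, saturate to FP8_E4M3_MAX, and cast to FP8_TYPE.
  out_vec.x = scaled_fp8_conversion(static_cast<float>(in_vec.x), scale);
  out_vec.y = scaled_fp8_conversion(static_cast<float>(in_vec.y), scale);
  out_vec.z = scaled_fp8_conversion(static_cast<float>(in_vec.z), scale);
  out_vec.w = scaled_fp8_conversion(static_cast<float>(in_vec.w), scale);
  vectorized_out[i] = out_vec;
}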