[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
ElizaWszola
2026-02-18 08:35:04 +01:00
committed by GitHub
parent a49ea5a58f
commit a88b3be7c4
12 changed files with 234 additions and 75 deletions

View File

@@ -74,7 +74,7 @@ __device__ void compute_dynamic_per_token_scales(
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
int32_t const group_size = 0) {
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
float block_absmax_val_maybe = 0.0f;
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
__syncthreads();
@@ -133,7 +133,9 @@ __device__ void compute_dynamic_per_token_scales(
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store
if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
blockIdx.x] = scale;
} else {
all_token_scales[blockIdx.x * num_groups +
@@ -180,13 +182,11 @@ __device__ void compute_dynamic_per_token_scales(
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight,
float const rms, float* const scale,
int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr,
int32_t const group_size = 0) {
__device__ void norm_and_quant(
scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight, float const rms, float* const scale,
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr,
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
@@ -202,7 +202,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t scale_idx = 0;
if (group_size > 0) {
if constexpr (is_scale_transposed) {
scale_idx = (i / group_size) * gridDim.x + blockIdx.x;
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
scale_idx = (i / group_size) * scale_rows + blockIdx.x;
} else {
scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
}
@@ -286,8 +288,8 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
const int VEC_SIZE = 4;
@@ -382,7 +384,9 @@ __device__ void compute_dynamic_per_token_scales(
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store
if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
blockIdx.x] = scale;
} else {
all_token_scales[blockIdx.x * num_groups +
@@ -463,7 +467,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ weight,
float const rms, float* const scale,
int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
scalar_t* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
@@ -516,7 +521,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t const num_groups = hidden_size / group_size;
int64_t scale_idx = 0;
if constexpr (is_scale_transposed) {
scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x;
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
scale_idx = (i * VEC_SIZE / group_size) * scale_rows + blockIdx.x;
} else {
scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
}