[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -315,7 +315,9 @@ void silu_and_mul_scaled_fp4_experts_quant(
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0);
|
||||
double fp8_max, bool scale_ue8m0,
|
||||
bool dummy_is_scale_transposed,
|
||||
bool dummy_is_tma_aligned);
|
||||
|
||||
void per_token_group_quant_int8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
|
||||
@@ -97,7 +97,7 @@ __global__ void rms_norm_per_block_quant_kernel(
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
|
||||
float rms;
|
||||
// Compute RMS
|
||||
// Always able to vectorize due to constraints on hidden_size
|
||||
@@ -108,7 +108,8 @@ __global__ void rms_norm_per_block_quant_kernel(
|
||||
// Always able to vectorize due to constraints on hidden_size and group_size
|
||||
vllm::vectorized::compute_dynamic_per_token_scales<
|
||||
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
|
||||
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual);
|
||||
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual,
|
||||
outer_scale_stride);
|
||||
|
||||
// RMS Norm + Quant
|
||||
// Always able to vectorize due to constraints on hidden_size
|
||||
@@ -119,7 +120,8 @@ __global__ void rms_norm_per_block_quant_kernel(
|
||||
vllm::vectorized::norm_and_quant<
|
||||
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
|
||||
has_residual, is_scale_transposed, group_size>(
|
||||
out, input, weight, rms, scales, hidden_size, residual);
|
||||
out, input, weight, rms, scales, hidden_size, residual,
|
||||
outer_scale_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -225,7 +227,8 @@ void rms_norm_per_block_quant_dispatch(
|
||||
: nullptr,
|
||||
var_epsilon, hidden_size,
|
||||
has_residual ? residual->data_ptr<scalar_in_t>()
|
||||
: nullptr);
|
||||
: nullptr,
|
||||
scales.stride(1));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -257,6 +260,11 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
"Unsupported group size: ", group_size);
|
||||
|
||||
if (scales.stride(1) > 1) {
|
||||
TORCH_CHECK(is_scale_transposed,
|
||||
"Outer scale stride must be 1 when scales are not transposed");
|
||||
}
|
||||
|
||||
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
|
||||
var_epsilon, scale_ub, residual,
|
||||
is_scale_transposed);
|
||||
|
||||
@@ -74,7 +74,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0) {
|
||||
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
__syncthreads();
|
||||
@@ -133,7 +133,9 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
// Global output store
|
||||
if constexpr (is_scale_transposed) {
|
||||
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
|
||||
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
|
||||
outer_scale_stride * outer_scale_stride;
|
||||
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
|
||||
blockIdx.x] = scale;
|
||||
} else {
|
||||
all_token_scales[blockIdx.x * num_groups +
|
||||
@@ -180,13 +182,11 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
|
||||
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
|
||||
bool has_residual = false, bool is_scale_transposed = false>
|
||||
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight,
|
||||
float const rms, float* const scale,
|
||||
int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0) {
|
||||
__device__ void norm_and_quant(
|
||||
scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight, float const rms, float* const scale,
|
||||
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
@@ -202,7 +202,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
int64_t scale_idx = 0;
|
||||
if (group_size > 0) {
|
||||
if constexpr (is_scale_transposed) {
|
||||
scale_idx = (i / group_size) * gridDim.x + blockIdx.x;
|
||||
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
|
||||
outer_scale_stride * outer_scale_stride;
|
||||
scale_idx = (i / group_size) * scale_rows + blockIdx.x;
|
||||
} else {
|
||||
scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
|
||||
}
|
||||
@@ -286,8 +288,8 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
int32_t const hidden_size,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
|
||||
int64_t outer_scale_stride = 1) {
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
const int VEC_SIZE = 4;
|
||||
@@ -382,7 +384,9 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
// Global output store
|
||||
if constexpr (is_scale_transposed) {
|
||||
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x +
|
||||
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
|
||||
outer_scale_stride * outer_scale_stride;
|
||||
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
|
||||
blockIdx.x] = scale;
|
||||
} else {
|
||||
all_token_scales[blockIdx.x * num_groups +
|
||||
@@ -463,7 +467,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
scalar_t const* __restrict__ weight,
|
||||
float const rms, float* const scale,
|
||||
int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
scalar_t* __restrict__ residual = nullptr,
|
||||
int64_t outer_scale_stride = 1) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
|
||||
@@ -516,7 +521,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
int64_t const num_groups = hidden_size / group_size;
|
||||
int64_t scale_idx = 0;
|
||||
if constexpr (is_scale_transposed) {
|
||||
scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x;
|
||||
int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
|
||||
outer_scale_stride * outer_scale_stride;
|
||||
scale_idx = (i * VEC_SIZE / group_size) * scale_rows + blockIdx.x;
|
||||
} else {
|
||||
scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
|
||||
}
|
||||
|
||||
@@ -379,7 +379,9 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0) {
|
||||
double fp8_max, bool scale_ue8m0,
|
||||
bool dummy_is_scale_transposed = false,
|
||||
bool dummy_is_tma_aligned = false) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
||||
@@ -643,11 +643,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||
// The dummy arguments are here so we can correctly fuse with RMSNorm.
|
||||
ops.def(
|
||||
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, "
|
||||
"int group_size, float eps, float fp8_min, float fp8_max, bool "
|
||||
"scale_ue8m0) -> ()");
|
||||
"scale_ue8m0, bool dummy_is_scale_transposed, bool dummy_is_tma_aligned "
|
||||
") -> ()");
|
||||
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
||||
&per_token_group_quant_fp8);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user