[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -15,31 +15,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
|
||||
float rms = 0.0f;
|
||||
float token_scale = 0.0f;
|
||||
|
||||
// Compute rms
|
||||
vllm::vectorized::compute_rms<scalar_t, has_residual>(
|
||||
&rms, input, hidden_size, var_epsilon, residual);
|
||||
&rms, input, hidden_size, input_stride, var_epsilon, residual);
|
||||
|
||||
// Compute scale
|
||||
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
|
||||
has_residual>(
|
||||
&token_scale, scales, input, weight, rms, scale_ub, hidden_size,
|
||||
residual);
|
||||
input_stride, residual);
|
||||
|
||||
// RMS Norm + Quant
|
||||
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
|
||||
token_scale = 1.0f / token_scale;
|
||||
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
|
||||
has_residual>(
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
has_residual>(out, input, weight, rms,
|
||||
&token_scale, hidden_size,
|
||||
input_stride, residual);
|
||||
} else {
|
||||
// FP8 - Do not invert token_scale for exact match with FBGemm
|
||||
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
|
||||
has_residual>(
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
has_residual>(out, input, weight, rms,
|
||||
&token_scale, hidden_size,
|
||||
input_stride, residual);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,38 +53,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
bool const can_vectorize = hidden_size % 4 == 0;
|
||||
bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0;
|
||||
|
||||
if (can_vectorize) {
|
||||
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
|
||||
has_residual>(
|
||||
out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
|
||||
residual);
|
||||
input_stride, residual);
|
||||
}
|
||||
|
||||
float rms = 0.0f;
|
||||
float token_scale = 0.0f;
|
||||
|
||||
// Compute RMS
|
||||
vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
|
||||
var_epsilon, residual);
|
||||
vllm::compute_rms<scalar_t, has_residual>(
|
||||
&rms, input, hidden_size, input_stride, var_epsilon, residual);
|
||||
// Compute Scale
|
||||
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
|
||||
&token_scale, scales, input, weight, rms, scale_ub, hidden_size,
|
||||
residual);
|
||||
input_stride, residual);
|
||||
|
||||
// RMS Norm + Quant
|
||||
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
|
||||
token_scale = 1.0f / token_scale;
|
||||
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
out, input, weight, rms, &token_scale, hidden_size, input_stride,
|
||||
residual);
|
||||
} else {
|
||||
// FP8 - Do not invert s_token_scale for exact match with FBGemm
|
||||
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
|
||||
out, input, weight, rms, &token_scale, hidden_size, residual);
|
||||
out, input, weight, rms, &token_scale, hidden_size, input_stride,
|
||||
residual);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,19 +101,20 @@ __global__ void rms_norm_per_block_quant_kernel(
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
|
||||
int32_t const input_stride, scalar_t* __restrict__ residual = nullptr,
|
||||
int64_t outer_scale_stride = 1) {
|
||||
float rms;
|
||||
// Compute RMS
|
||||
// Always able to vectorize due to constraints on hidden_size
|
||||
vllm::vectorized::compute_rms<scalar_t, has_residual>(
|
||||
&rms, input, hidden_size, var_epsilon, residual);
|
||||
&rms, input, hidden_size, input_stride, var_epsilon, residual);
|
||||
|
||||
// Compute Scale
|
||||
// Always able to vectorize due to constraints on hidden_size and group_size
|
||||
vllm::vectorized::compute_dynamic_per_token_scales<
|
||||
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
|
||||
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual,
|
||||
outer_scale_stride);
|
||||
nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride,
|
||||
residual, outer_scale_stride);
|
||||
|
||||
// RMS Norm + Quant
|
||||
// Always able to vectorize due to constraints on hidden_size
|
||||
@@ -120,7 +125,7 @@ __global__ void rms_norm_per_block_quant_kernel(
|
||||
vllm::vectorized::norm_and_quant<
|
||||
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
|
||||
has_residual, is_scale_transposed, group_size>(
|
||||
out, input, weight, rms, scales, hidden_size, residual,
|
||||
out, input, weight, rms, scales, hidden_size, input_stride, residual,
|
||||
outer_scale_stride);
|
||||
}
|
||||
|
||||
@@ -137,6 +142,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
std::optional<at::Tensor> const& scale_ub,
|
||||
std::optional<at::Tensor>& residual) {
|
||||
int32_t hidden_size = input.size(-1);
|
||||
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
|
||||
auto num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@@ -153,7 +159,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
var_epsilon, hidden_size,
|
||||
var_epsilon, hidden_size, input_stride,
|
||||
has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
|
||||
});
|
||||
});
|
||||
@@ -170,7 +176,9 @@ void rms_norm_dynamic_per_token_quant(
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"Input must be contiguous in the last dimension");
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
@@ -179,6 +187,7 @@ void rms_norm_dynamic_per_token_quant(
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32);
|
||||
if (residual) {
|
||||
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
|
||||
TORCH_CHECK(residual->is_contiguous());
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
@@ -200,6 +209,15 @@ void rms_norm_per_block_quant_dispatch(
|
||||
std::optional<at::Tensor> const& scale_ub,
|
||||
std::optional<at::Tensor>& residual, bool is_scale_transposed) {
|
||||
int32_t hidden_size = input.size(-1);
|
||||
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
|
||||
|
||||
TORCH_CHECK(hidden_size % 4 == 0,
|
||||
"Hidden size must be divisible by 4 for vectorized access");
|
||||
TORCH_CHECK(input_stride % 4 == 0,
|
||||
"Input stride must be divisible by 4 for vectorized access");
|
||||
TORCH_CHECK(group_size % 4 == 0,
|
||||
"Group size must be divisible by 4 for vectorized access");
|
||||
|
||||
auto num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@@ -225,7 +243,7 @@ void rms_norm_per_block_quant_dispatch(
|
||||
weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
var_epsilon, hidden_size,
|
||||
var_epsilon, hidden_size, input_stride,
|
||||
has_residual ? residual->data_ptr<scalar_in_t>()
|
||||
: nullptr,
|
||||
scales.stride(1));
|
||||
@@ -246,7 +264,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.stride(-1) == 1,
|
||||
"Input must be contiguous in the last dimension");
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
@@ -255,6 +275,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32);
|
||||
if (residual) {
|
||||
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
|
||||
TORCH_CHECK(residual->is_contiguous());
|
||||
}
|
||||
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
|
||||
@@ -16,14 +16,17 @@ namespace vllm {
|
||||
// has_residual must be true, if residual is not a nullptr
|
||||
template <typename scalar_t, bool has_residual = false>
|
||||
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
int32_t const hidden_size, float const epsilon,
|
||||
int32_t const hidden_size,
|
||||
int32_t const input_stride, float const epsilon,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int64_t const input_token_offset =
|
||||
blockIdx.x * static_cast<int64_t>(input_stride);
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
// sum of squares
|
||||
float ss = 0.0f;
|
||||
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
float x = static_cast<float>(input[input_token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
}
|
||||
@@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
|
||||
int32_t const hidden_size, int32_t const input_stride,
|
||||
scalar_t const* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
__syncthreads();
|
||||
|
||||
int64_t const input_token_offset =
|
||||
blockIdx.x * static_cast<int64_t>(input_stride);
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
if (group_size > 0) {
|
||||
__shared__ float s_max_vals[1024];
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
int64_t num_groups = hidden_size / group_size;
|
||||
__shared__ float s_max_vals[1024];
|
||||
int64_t const threads_per_group = blockDim.x / num_groups;
|
||||
int64_t const thread_in_group = threadIdx.x % threads_per_group;
|
||||
int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
|
||||
@@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
int64_t const thread_end =
|
||||
min(group_offset + group_size, static_cast<int64_t>(hidden_size));
|
||||
for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
float x = static_cast<float>(input[input_token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
}
|
||||
@@ -144,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
float x = static_cast<float>(input[input_token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
}
|
||||
@@ -185,12 +191,15 @@ template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
|
||||
__device__ void norm_and_quant(
|
||||
scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight, float const rms, float* const scale,
|
||||
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr,
|
||||
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
|
||||
int32_t const hidden_size, int32_t const input_stride,
|
||||
scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0,
|
||||
int64_t outer_scale_stride = 1) {
|
||||
int64_t const input_token_offset =
|
||||
blockIdx.x * static_cast<int64_t>(input_stride);
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
float x = static_cast<float>(input[input_token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
residual[token_offset + i] = static_cast<scalar_t>(x);
|
||||
@@ -224,13 +233,16 @@ namespace vectorized {
|
||||
// hidden_size must be a multiple of 4
|
||||
template <typename scalar_t, bool has_residual = false>
|
||||
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
int32_t const hidden_size, float const epsilon,
|
||||
int32_t const hidden_size,
|
||||
int32_t const input_stride, float const epsilon,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int64_t const input_token_offset =
|
||||
blockIdx.x * static_cast<int64_t>(input_stride);
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vec_input =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
|
||||
vec4_t<scalar_t> const* vec_residual = nullptr;
|
||||
if constexpr (has_residual) {
|
||||
vec_residual =
|
||||
@@ -288,7 +300,8 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
|
||||
int32_t const hidden_size, int32_t const input_stride,
|
||||
scalar_t const* __restrict__ residual = nullptr,
|
||||
int64_t outer_scale_stride = 1) {
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
@@ -300,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
vec4_t<scalar_t> const* vec_weight = nullptr;
|
||||
vec4_t<scalar_t> const* vec_residual = nullptr;
|
||||
|
||||
int64_t const input_token_offset =
|
||||
blockIdx.x * static_cast<int64_t>(input_stride);
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
if constexpr (group_size > 0) {
|
||||
__shared__ float s_max_vals[1024];
|
||||
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
int64_t const num_groups = hidden_size / group_size;
|
||||
int64_t const threads_per_group = blockDim.x / num_groups;
|
||||
int64_t const thread_in_group = threadIdx.x % threads_per_group;
|
||||
@@ -312,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
int64_t const thread_offset = group_offset + thread_in_group;
|
||||
int64_t const thread_end = min(group_offset + (group_size >> 2),
|
||||
static_cast<int64_t>(hidden_size >> 2));
|
||||
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
vec_input =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
|
||||
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
|
||||
if constexpr (has_residual) {
|
||||
vec_residual =
|
||||
@@ -396,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
__syncthreads();
|
||||
|
||||
} else {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
vec_input =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
|
||||
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
|
||||
if constexpr (has_residual) {
|
||||
vec_residual =
|
||||
@@ -462,18 +479,18 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
|
||||
bool has_residual = false, bool is_scale_transposed = false,
|
||||
int32_t group_size = 0>
|
||||
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight,
|
||||
float const rms, float* const scale,
|
||||
int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr,
|
||||
int64_t outer_scale_stride = 1) {
|
||||
__device__ void norm_and_quant(
|
||||
scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
|
||||
scalar_t const* __restrict__ weight, float const rms, float* const scale,
|
||||
int32_t const hidden_size, int32_t const input_stride,
|
||||
scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
|
||||
int64_t const input_token_offset =
|
||||
blockIdx.x * static_cast<int64_t>(input_stride);
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
|
||||
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vec_input =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
|
||||
vec4_t<scalar_t> const* vec_weight =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(weight);
|
||||
q8x4_t<scalar_out_t>* vec_output =
|
||||
|
||||
Reference in New Issue
Block a user