[torch.compile] Dynamic fp8 + rms_norm fusion (#10906)
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@@ -0,0 +1,160 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"

namespace vllm {

template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void rms_norm_dynamic_per_token_quant_vec(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  float rms = 0.0f;
  float token_scale = 0.0f;

  // Compute RMS
  vllm::vectorized::compute_rms<scalar_t, has_residual>(
      &rms, input, hidden_size, var_epsilon, residual);

  // Compute scale
  vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
                                                     has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // RMS norm + quant
  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
    vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
                                     has_residual>(
        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
  } else {
    // FP8 - do not invert token_scale, for an exact match with FBGemm
    vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
                                     has_residual>(
        out, input, weight, rms, token_scale, hidden_size, residual);
  }
}

// RMS norm + quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_dynamic_per_token_quant_kernel(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  if (can_vectorize) {
    return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
                                                has_residual>(
        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
        hidden_size, residual);
  }

  float rms = 0.0f;
  float token_scale = 0.0f;

  // Compute RMS
  vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
                                            var_epsilon, residual);
  // Compute scale
  vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // RMS norm + quant
  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
    vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
  } else {
    // FP8 - do not invert token_scale, for an exact match with FBGemm
    vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
        out, input, weight, rms, token_scale, hidden_size, residual);
  }
}

}  // namespace vllm

// Residual add + RMS norm + dynamic per-token quant
template <typename scalar_in_t>
void rms_norm_dynamic_per_token_quant_dispatch(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> const& scale_ub,
    std::optional<at::Tensor>& residual) {
  int32_t hidden_size = input.size(-1);
  int32_t num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const float min_scaling_factor =
      out.dtype() == torch::kInt8
          ? std::numeric_limits<float>::epsilon()
          : 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);

  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        true>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  var_epsilon, min_scaling_factor, hidden_size,
                  residual->data_ptr<scalar_in_t>());
        });
  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        false>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  var_epsilon, min_scaling_factor, hidden_size, nullptr);
        });
  }
}

void rms_norm_dynamic_per_token_quant(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

  if (scale_ub.has_value()) {
    TORCH_CHECK(out.dtype() == kFp8Type);
  }
  TORCH_CHECK(scales.dtype() == torch::kFloat32);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
        rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
            out, input, weight, scales, var_epsilon, scale_ub, residual);
      });
}
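For orientation, a minimal host-side sketch of how this entry point can be called (illustrative only, not part of the commit; it assumes a libtorch build with Float8_e4m3fn support and this translation unit linked in; example_rms_norm_fp8_quant is hypothetical):

#include <torch/torch.h>
#include <optional>

// Hypothetical driver: plain RMS norm + dynamic per-token fp8 quant,
// with no scale upper bound and no residual.
void example_rms_norm_fp8_quant() {
  int64_t const num_tokens = 16, hidden_size = 1024;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({num_tokens, hidden_size}, opts);
  torch::Tensor weight = torch::randn({hidden_size}, opts);
  torch::Tensor out = torch::empty(
      {num_tokens, hidden_size}, opts.dtype(at::ScalarType::Float8_e4m3fn));
  torch::Tensor scales =
      torch::empty({num_tokens}, opts.dtype(torch::kFloat32));
  rms_norm_dynamic_per_token_quant(out, input, weight, scales,
                                   /*var_epsilon=*/1e-6,
                                   /*scale_ub=*/std::nullopt,
                                   /*residual=*/std::nullopt);
}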
327  csrc/quantization/fused_kernels/layernorm_utils.cuh  Normal file
@@ -0,0 +1,327 @@
#pragma once

/**
 * __device__ layernorm utilities.
 */

#include "quantization/vectorization.cuh"
#include "quant_conversions.cuh"

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

// has_residual must be true if residual is not a nullptr
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  // sum of squares
  float ss = 0.0f;

  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
    }

    ss += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);

  __shared__ float s_rms;
  if (threadIdx.x == 0) {
    s_rms = rsqrtf(ss / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = s_rms;
}

template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};

  float block_absmax_val_maybe = 0.0f;
  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
    }

    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  block_absmax_val_maybe =
      BlockReduce(reduceStore)
          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);

  __shared__ float s_token_scale;
  if (threadIdx.x == 0) {
    float scale = 0.0f;
    if (scale_ub) {
      scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      scale = block_absmax_val_maybe;
    }
    // token scale computation
    scale = max(scale / qmax, min_scaling_factor);
    s_token_scale = scale;                 // Shared memory store
    all_token_scales[blockIdx.x] = scale;  // Global output store
  }
  __syncthreads();

  *token_scale = s_token_scale;
}

template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
      residual[token_offset + i] = static_cast<scalar_t>(x);
    }
    // Norm
    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    // Quant
    output[token_offset + i] =
        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
  }
}

namespace vectorized {

// Compute 1.0/rms(input)
// hidden_size must be a multiple of 4
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  // sum of squares
  float ss = 0.0f;

  int32_t const num_vec_elems = hidden_size >> 2;

#pragma unroll 4
  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> in = vec_input[i];

    vec4_t<float> x;
    x.x = static_cast<float>(in.x);
    x.y = static_cast<float>(in.y);
    x.z = static_cast<float>(in.z);
    x.w = static_cast<float>(in.w);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      x.x += static_cast<float>(r.x);
      x.y += static_cast<float>(r.y);
      x.z += static_cast<float>(r.z);
      x.w += static_cast<float>(r.w);
    }

    ss += x.x * x.x;
    ss += x.y * x.y;
    ss += x.z * x.z;
    ss += x.w * x.w;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);

  __shared__ float s_rms;
  if (threadIdx.x == 0) {
    s_rms = rsqrtf(ss / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = s_rms;
}

// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};

  int32_t const num_vec_elems = hidden_size >> 2;
  float block_absmax_val_maybe = 0.0f;

#pragma unroll 4
  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> in = vec_input[i];
    vec4_t<scalar_t> const w = vec_weight[i];

    vec4_t<float> x;
    x.x = static_cast<float>(in.x);
    x.y = static_cast<float>(in.y);
    x.z = static_cast<float>(in.z);
    x.w = static_cast<float>(in.w);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      x.x += static_cast<float>(r.x);
      x.y += static_cast<float>(r.y);
      x.z += static_cast<float>(r.z);
      x.w += static_cast<float>(r.w);
    }

    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  block_absmax_val_maybe =
      BlockReduce(reduceStore)
          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);

  __shared__ float s_token_scale;
  if (threadIdx.x == 0) {
    float scale = 0.0f;
    if (scale_ub) {
      scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      scale = block_absmax_val_maybe;
    }
    // token scale computation
    scale = max(scale / qmax, min_scaling_factor);
    s_token_scale = scale;                 // shared memory store
    all_token_scales[blockIdx.x] = scale;  // global output store
  }
  __syncthreads();

  *token_scale = s_token_scale;
}

// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/output/weight/residual to better utilize memory
  // bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  q8x4_t<scalar_out_t>* vec_output =
      reinterpret_cast<q8x4_t<scalar_out_t>*>(&output[token_offset]);
  vec4_t<scalar_t>* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
  }

  int32_t const num_vec_elems = hidden_size >> 2;

// TODO(luka/varun) extract into type-agnostic vectorized quant function to
// replace scaled_fp8_conversion_vec
#pragma unroll 4
  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> const in = vec_input[i];
    vec4_t<scalar_t> const w = vec_weight[i];

    vec4_t<float> x;
    x.x = static_cast<float>(in.x);
    x.y = static_cast<float>(in.y);
    x.z = static_cast<float>(in.z);
    x.w = static_cast<float>(in.w);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      x.x += static_cast<float>(r.x);
      x.y += static_cast<float>(r.y);
      x.z += static_cast<float>(r.z);
      x.w += static_cast<float>(r.w);
      // Update residual
      r.x = static_cast<scalar_t>(x.x);
      r.y = static_cast<scalar_t>(x.y);
      r.z = static_cast<scalar_t>(x.z);
      r.w = static_cast<scalar_t>(x.w);
      vec_residual[i] = r;
    }

    q8x4_t<scalar_out_t> out;
    out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.x * rms) * w.x, scale);
    out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.y * rms) * w.y, scale);
    out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.z * rms) * w.z, scale);
    out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.w * rms) * w.w, scale);
    vec_output[i] = out;
  }
}

}  // namespace vectorized

}  // namespace vllm
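As a reading aid, here is a hedged scalar restatement of the per-token scale math that both the plain and vectorized utilities above implement (illustrative host C++, not part of the commit; reference_token_scale is hypothetical, and qmax would be 127.0f for int8 or the e4m3 maximum for fp8):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// "rms" in the kernels is the reciprocal root-mean-square:
// rsqrtf(ss / hidden_size + epsilon).
float reference_token_scale(std::vector<float> const& x,
                            std::vector<float> const& w, float eps, float qmax,
                            float min_scaling_factor) {
  float ss = 0.0f;
  for (float v : x) ss += v * v;
  float const rms = 1.0f / std::sqrt(ss / x.size() + eps);

  float absmax = 0.0f;
  for (size_t i = 0; i < x.size(); ++i)
    absmax = std::max(absmax, std::fabs(x[i] * rms * w[i]));

  // Same clamping as compute_dynamic_per_token_scales: the scale is floored
  // at min_scaling_factor (and optionally capped by *scale_ub beforehand).
  return std::max(absmax / qmax, min_scaling_factor);
}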
81  csrc/quantization/fused_kernels/quant_conversions.cuh  Normal file
@@ -0,0 +1,81 @@
#pragma once

/**
 * __device__ helper functions to deal with float -> quant datatype conversion
 */

#include "quantization/vectorization.cuh"
// TODO(luka/varun): refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"

namespace vllm {

// TODO(luka/varun): combine into common utilities for int8
// (with int8_quant_kernels.cu)
static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
  float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<FP8_TYPE>(r);
}

template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
struct ScaledQuant;

template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
    quant_type_t, is_scale_inverted,
    typename std::enable_if_t<std::is_same_v<quant_type_t, int8_t>>> {
  static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                          float const scale) {
    if constexpr (is_scale_inverted) {
      return float_to_int8_rn(x * scale);
    } else {
      return float_to_int8_rn(x / scale);
    }
  }
};

template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
    quant_type_t, is_scale_inverted,
    typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
  static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                          float const scale) {
    if constexpr (is_scale_inverted) {
      return float_to_fp8(x * scale);
    } else {
      return float_to_fp8(x / scale);
    }
  }
};

template <typename scalar_t, typename quant_type_t, bool is_scale_inverted>
__device__ void scaled_quant_conversion(quant_type_t* __restrict__ output,
                                        scalar_t const* __restrict__ input,
                                        float const scale, int const tid,
                                        int const num_elements,
                                        int const step) {
  for (int i = tid; i < num_elements; i += step) {
    output[i] =
        ScaledQuant<quant_type_t, is_scale_inverted>::quant_fn(input[i], scale);
  }
}

}  // namespace vllm
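A hedged usage sketch for ScaledQuant (illustrative, not part of the commit; toy_int8_quant_kernel is hypothetical): a toy grid-stride kernel that quantizes a float buffer to int8 with a precomputed inverted, i.e. multiplicative, scale, the same mode the int8 path of the fused kernels selects.

// Assumes this header is included.
__global__ void toy_int8_quant_kernel(int8_t* __restrict__ out,
                                      float const* __restrict__ in,
                                      float const inv_scale, int const n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    // is_scale_inverted = true, so quant_fn multiplies by inv_scale.
    out[i] = vllm::ScaledQuant<int8_t, true>::quant_fn(in[i], inv_scale);
  }
}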