[Feature][ROCm]Enable fusion pass for torch.compile on ROCm (#15050)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -14,8 +14,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
|
||||
float* __restrict__ scales, // [num_tokens]
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon,
|
||||
float const min_scaling_factor, int32_t const hidden_size,
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
float rms = 0.0f;
|
||||
float token_scale = 0.0f;
|
||||
@@ -27,8 +26,8 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
|
||||
// Compute scale
|
||||
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
|
||||
has_residual>(
|
||||
&token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
|
||||
hidden_size, residual);
|
||||
&token_scale, scales, input, weight, rms, scale_ub, hidden_size,
|
||||
residual);
|
||||
|
||||
// RMS Norm + Quant
|
||||
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
|
||||
@@ -50,8 +49,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
|
||||
float* __restrict__ scales, // [num_tokens]
|
||||
scalar_t const* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t const* __restrict__ weight, // [hidden_size]
|
||||
float const* scale_ub, float const var_epsilon,
|
||||
float const min_scaling_factor, int32_t const hidden_size,
|
||||
float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
|
||||
scalar_t* __restrict__ residual = nullptr) {
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
@@ -60,8 +58,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
|
||||
if (can_vectorize) {
|
||||
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
|
||||
has_residual>(
|
||||
out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
|
||||
hidden_size, residual);
|
||||
out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
|
||||
residual);
|
||||
}
|
||||
|
||||
float rms = 0.0f;
|
||||
@@ -72,8 +70,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
|
||||
var_epsilon, residual);
|
||||
// Compute Scale
|
||||
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
|
||||
&token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
|
||||
hidden_size, residual);
|
||||
&token_scale, scales, input, weight, rms, scale_ub, hidden_size,
|
||||
residual);
|
||||
|
||||
// RMS Norm + Quant
|
||||
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
|
||||
@@ -105,11 +103,6 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const float min_scaling_factor =
|
||||
out.dtype() == torch::kInt8
|
||||
? std::numeric_limits<float>::epsilon()
|
||||
: 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);
|
||||
|
||||
if (residual.has_value()) {
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
|
||||
@@ -119,8 +112,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
var_epsilon, min_scaling_factor, hidden_size,
|
||||
residual->data_ptr<scalar_in_t>());
|
||||
var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
|
||||
});
|
||||
|
||||
} else {
|
||||
@@ -132,7 +124,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
|
||||
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
var_epsilon, min_scaling_factor, hidden_size, nullptr);
|
||||
var_epsilon, hidden_size, nullptr);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
#include "quantization/vectorization.cuh"
|
||||
#include "quantization/utils.cuh"
|
||||
#include "quant_conversions.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
@@ -51,11 +52,11 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
float const min_scaling_factor, int32_t const hidden_size,
|
||||
int32_t const hidden_size,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor);
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
s_token_scale = scale; // Shared memory store
|
||||
all_token_scales[blockIdx.x] = scale; // Global output store
|
||||
}
|
||||
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
|
||||
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
|
||||
float const rms, float const* __restrict__ scale_ub,
|
||||
float const min_scaling_factor, int32_t const hidden_size,
|
||||
int32_t const hidden_size,
|
||||
scalar_t const* __restrict__ residual = nullptr) {
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
@@ -200,7 +201,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
|
||||
}
|
||||
|
||||
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
|
||||
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
|
||||
|
||||
int32_t const num_vec_elems = hidden_size >> 2;
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
scale = max(scale / qmax, min_scaling_factor);
|
||||
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
|
||||
s_token_scale = scale; // shared memory store
|
||||
all_token_scales[blockIdx.x] = scale; // global output store
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
|
||||
|
||||
template <typename fp8_type>
|
||||
static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
|
||||
float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
|
||||
float const r =
|
||||
fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
|
||||
return static_cast<fp8_type>(r);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user