[Feature][ROCm]Enable fusion pass for torch.compile on ROCm (#15050)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-03-31 06:42:18 -05:00
committed by GitHub
parent effc5d24fa
commit e85829450d
8 changed files with 92 additions and 72 deletions

View File

@@ -5,6 +5,7 @@
*/
#include "quantization/vectorization.cuh"
#include "quantization/utils.cuh"
#include "quant_conversions.cuh"
#ifndef USE_ROCM
@@ -51,11 +52,11 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
float const min_scaling_factor, int32_t const hidden_size,
int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
;
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
float block_absmax_val_maybe = 0.0f;
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor);
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
s_token_scale = scale; // Shared memory store
all_token_scales[blockIdx.x] = scale; // Global output store
}
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
float const min_scaling_factor, int32_t const hidden_size,
int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
;
@@ -200,7 +201,7 @@ __device__ void compute_dynamic_per_token_scales(
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
}
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
int32_t const num_vec_elems = hidden_size >> 2;
float block_absmax_val_maybe = 0.0f;
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor);
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
s_token_scale = scale; // shared memory store
all_token_scales[blockIdx.x] = scale; // global output store
}