[Feature][ROCm]Enable fusion pass for torch.compile on ROCm (#15050)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-03-31 06:42:18 -05:00
committed by GitHub
parent effc5d24fa
commit e85829450d
8 changed files with 92 additions and 72 deletions

View File

@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor =
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
min_scaling_factor);
token_scale = max(token_scale / quant_type_max_v<fp8_type>,
min_scaling_factor<fp8_type>::val());
scale[token_idx] = token_scale;
}
__syncthreads();