[Feature][ROCm]Enable fusion pass for torch.compile on ROCm (#15050)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
float const min_scaling_factor =
|
||||
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
|
||||
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
|
||||
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
min_scaling_factor);
|
||||
token_scale = max(token_scale / quant_type_max_v<fp8_type>,
|
||||
min_scaling_factor<fp8_type>::val());
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
Reference in New Issue
Block a user