[AMD][Kernel][BugFix] Use correct scale in concat_and_cache_ds_mla_kernel when on gfx942 (#32976)
Signed-off-by: Randall Smith <ransmith@amd.com> Signed-off-by: Randall Smith <Randall.Smith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -24,6 +24,12 @@
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#if defined(__gfx942__)
|
||||
constexpr float kFp8ScaleDivisor = 224.f;
|
||||
#else
|
||||
constexpr float kFp8ScaleDivisor = 448.f;
|
||||
#endif
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
int64_t block_size_in_bytes,
|
||||
const torch::Tensor& block_mapping) {
|
||||
@@ -401,8 +407,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
|
||||
}
|
||||
|
||||
// Compute the scale for the tile
|
||||
float tile_scale = max_abs / 448.f;
|
||||
tile_scale = fmaxf(tile_scale, FLT_MIN);
|
||||
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN);
|
||||
|
||||
// The first lane of each half-warp writes the scale to kv_cache
|
||||
if ((lane_idx == 0) || (lane_idx == 16)) {
|
||||
@@ -471,11 +476,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx942__)
|
||||
float scale = fmaxf(amax, 1e-4) / 224.0f;
|
||||
#else
|
||||
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||
#endif
|
||||
float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor;
|
||||
|
||||
if (use_ue8m0) {
|
||||
scale = exp2f(ceilf(log2f(scale)));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user