[AMD][Kernel][BugFix] Use correct scale in concat_and_cache_ds_mla_kernel when on gfx942 (#32976)

Signed-off-by: Randall Smith <ransmith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith
2026-01-27 01:16:43 -06:00
committed by GitHub
parent b539f988e1
commit 58996f3589

View File

@@ -24,6 +24,12 @@
typedef __hip_bfloat16 __nv_bfloat16;
#endif
#if defined(__gfx942__)
constexpr float kFp8ScaleDivisor = 224.f;
#else
constexpr float kFp8ScaleDivisor = 448.f;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) {
@@ -401,8 +407,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
}
// Compute the scale for the tile
float tile_scale = max_abs / 448.f;
tile_scale = fmaxf(tile_scale, FLT_MIN);
float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {
@@ -471,11 +476,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
#endif
}
#if defined(__gfx942__)
float scale = fmaxf(amax, 1e-4) / 224.0f;
#else
float scale = fmaxf(amax, 1e-4) / 448.0f;
#endif
float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}