CRITICAL FIX: the Q/K SMEM canonical layout must be indexed with the local d (0..15), not full_d — the UMMA descriptor always reads from the start of sQ0/sK0, not from a per-K-tile offset

This commit is contained in:
2026-05-28 20:13:52 +00:00
parent 08694b8136
commit deaa3ec725

View File

@@ -112,20 +112,22 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
if (is_load_warp) {
constexpr int CORES_MN = 128 / 8;
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
// Q loading: write to same canonical positions for each K-tile
// The UMMA descriptor always reads from sQ0 start
for (int r = 0; r < T; r++) {
for (int d = lane; d < MMA_K_BF16; d += 32) {
int full_d = kt * MMA_K_BF16 + d;
int core_k = full_d / 8, local_c = full_d % 8;
int full_d = kt * MMA_K_BF16 + d; // GMEM index
int ck = d / 8, lc = d % 8; // canonical position (same for all kt)
int core_mn = r / 8, local_r = r % 8;
sQ0[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] =
sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] =
q_head[r * HD + full_d];
}
}
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
for (int r = 0; r < s_k; r++) {
for (int d = lane; d < MMA_K_BF16; d += 32) {
int full_d = kt * MMA_K_BF16 + d;
int ck = full_d / 8, lc = full_d % 8;
int full_d = kt * MMA_K_BF16 + d; // GMEM index
int ck = d / 8, lc = d % 8; // canonical position (same for all kt)
int tmn = r / 8, lr = r % 8;
sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d];
}