CRITICAL FIX: Q/K SMEM canonical layout must use local d (0..15) not full_d — UMMA descriptor reads from sQ0/sK0 start, not offset
This commit is contained in:
@@ -112,20 +112,22 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
if (is_load_warp) {
|
||||
constexpr int CORES_MN = 128 / 8;
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
|
||||
// Q loading: write to same canonical positions for each K-tile
|
||||
// The UMMA descriptor always reads from sQ0 start
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
int core_k = full_d / 8, local_c = full_d % 8;
|
||||
int full_d = kt * MMA_K_BF16 + d; // GMEM index
|
||||
int ck = d / 8, lc = d % 8; // canonical position (same for all kt)
|
||||
int core_mn = r / 8, local_r = r % 8;
|
||||
sQ0[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] =
|
||||
sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] =
|
||||
q_head[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
|
||||
for (int r = 0; r < s_k; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
int ck = full_d / 8, lc = full_d % 8;
|
||||
int full_d = kt * MMA_K_BF16 + d; // GMEM index
|
||||
int ck = d / 8, lc = d % 8; // canonical position (same for all kt)
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user