diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index 043f6fb6..0182a613 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -112,20 +112,22 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { if (is_load_warp) { constexpr int CORES_MN = 128 / 8; for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + // Q loading: write to same canonical positions for each K-tile + // The UMMA descriptor always reads from sQ0 start for (int r = 0; r < T; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; - int core_k = full_d / 8, local_c = full_d % 8; + int full_d = kt * MMA_K_BF16 + d; // GMEM index + int ck = d / 8, lc = d % 8; // canonical position (same for all kt) int core_mn = r / 8, local_r = r % 8; - sQ0[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] = + sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = q_head[r * HD + full_d]; } } for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; for (int r = 0; r < s_k; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; - int ck = full_d / 8, lc = full_d % 8; + int full_d = kt * MMA_K_BF16 + d; // GMEM index + int ck = d / 8, lc = d % 8; // canonical position (same for all kt) int tmn = r / 8, lr = r % 8; sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d]; }