From deaa3ec7258fa39aedb6f411d92a139741af0fdf Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 20:13:52 +0000 Subject: [PATCH] =?UTF-8?q?CRITICAL=20FIX:=20Q/K=20SMEM=20canonical=20layo?= =?UTF-8?q?ut=20must=20use=20local=20d=20(0..15)=20not=20full=5Fd=20?= =?UTF-8?q?=E2=80=94=20UMMA=20descriptor=20reads=20from=20sQ0/sK0=20start,?= =?UTF-8?q?=20not=20offset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha_6warp_multirow.cuh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index 043f6fb6..0182a613 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -112,20 +112,22 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { if (is_load_warp) { constexpr int CORES_MN = 128 / 8; for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + // Q loading: write to same canonical positions for each K-tile + // The UMMA descriptor always reads from sQ0 start for (int r = 0; r < T; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; - int core_k = full_d / 8, local_c = full_d % 8; + int full_d = kt * MMA_K_BF16 + d; // GMEM index + int ck = d / 8, lc = d % 8; // canonical position (same for all kt) int core_mn = r / 8, local_r = r % 8; - sQ0[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] = + sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = q_head[r * HD + full_d]; } } for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; for (int r = 0; r < s_k; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; - int ck = full_d / 8, lc = full_d % 8; + int full_d = kt * MMA_K_BF16 + d; // GMEM index + int ck = d / 8, lc = d % 8; // canonical position (same for all kt) int tmn = r / 8, lr = r % 8; sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d]; }