From d7a0fc2bc2d609198ed3eb90620e84277ceaea7d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 17:59:57 +0000 Subject: [PATCH] CRITICAL FIX: K GMEM slice (None,None,0,0) not (None,0,None,0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit K from QK MMA B-partition has GMEM iter at mode 1, NOT mode 2. (None,0,None,0) hardcodes mode 1 to 0 → TMA always loads tile 0. (None,None,0,0) keeps mode 1 free → correct multi-tile loading. Proof: diag n=256 went from cos 0.711 → 0.999999 with this one change. --- tests/unit/test_fmha_v3_stage_c_full.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 57022f2e..a2e6cf4a 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -168,12 +168,16 @@ class FmhaV3StageC: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # GMEM slices: (None,0,None,0) keeps mode 2 free for both K and V. - # Debug shapes showed tBgK modes 1,2 are Int32(?), tVgV mode 2 grows with n. - # Both K and V have GMEM iteration at mode 2 in our tma_partition output. - tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine - tBgK = tBgK[(None,0,None,0)] # K: keep mode 2 free (GMEM iter) - tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 free (GMEM iter) + # GMEM slices: K uses mode 1 for GMEM iter → (None,None,0,0) keeps it free + # V uses mode 2 for GMEM iter → (None,0,None,0) keeps it free + # Q has 1 tile → (None,0,None,0) hardcode is fine + # CRITICAL: K from QK MMA B-partition has GMEM iter at mode 1, NOT mode 2! + # (None,0,None,0) for K hardcodes mode 1 to 0 → always loads tile 0. + # (None,None,0,0) for K keeps mode 1 free → correct multi-tile loading. + # Proven by diag test: (None,0,None,0) gives cos 0.711, (None,None,0,0) gives 0.999999. + tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile + tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free + tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV)