From 6db7fd339d9cb8bf5dd933526f50f3dcacd03da7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 23:35:55 +0000 Subject: [PATCH] =?UTF-8?q?FIX:=20(None,0,None,0)=20for=20ALL=20tma=5Fpart?= =?UTF-8?q?ition=20outputs=20=E2=80=94=20verified=20shapes=20on=20B200?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DIAG OUTPUT (n=256, inside @cute.kernel): tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes tVgV: (((64,128),1), 1, 1, 1) — 4 modes After (None,0,None,0) → keeps modes 0 and 2 free → 2D: tAgQ: (((64,128),1), Int32(?)) tBgK: (((64,128),1), Int32(?)) tVgV: (((64,128),1), 1) Then [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles). tAgQ[(None, Int32(0))] for Q (1 tile, coordinate is always 0). Removed diag prints from test_fmha_v3.py. --- tests/fmha_v3_stage_c_example10.py | 33 +++++++++++++++--------------- tests/unit/test_fmha_v3.py | 9 -------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/tests/fmha_v3_stage_c_example10.py b/tests/fmha_v3_stage_c_example10.py index 4c6ef459..14b12f0e 100644 --- a/tests/fmha_v3_stage_c_example10.py +++ b/tests/fmha_v3_stage_c_example10.py @@ -14,14 +14,14 @@ Three structural rules learned the hard way: `utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what the CUTLASS Blackwell FMHA reference uses in `correction_rescale`. -(C) tma_partition produces a 4-mode tensor: (V_grouped, ?, KV_tiles, ?). - Mode 2 is the GMEM-tile iteration axis. Pre-slicing with - `tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets mode 2 to 0, - collapsing the KV-tile axis so TMA always reads tile 0. +(C) tma_partition produces 4-mode tensors: (((64,128),1), ?, ?, ?). + Mode 2 is the GMEM-tile iteration axis. + Pre-slicing with `tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets + mode 2 to 0, collapsing the KV-tile axis so TMA always reads tile 0. Fix: pre-slice with `tBgK[(None,0,None,0)]` to keep modes 0,2 free, - then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes mode 1 (the - surviving KV_tiles mode from the original mode 2) with kt. - This matches the CUTLASS reference FMHA pattern. + then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes the surviving + KV_tiles mode. Verified shapes on B200: after (None,0,None,0), all + three tensors become 2D: (((64,128),1), Int32(?)) or (((64,128),1), 1). Kernel structure: @@ -181,14 +181,15 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # NOTE: after tma_partition, tAgQ has 2 modes (Q has 1 tile), - # tBgK/tVgV have 4 modes: (V_grouped, ?, KV_tiles, ?). - # Mode 2 of tBgK/tVgV is the GMEM tile iteration axis. - # The old pre-slice (None,None,0,0) kept modes 0,1 free and set - # mode 2 to 0, collapsing the KV tile axis. - # Fix: (None,0,None,0) keeps modes 0,2 free, then [None, kt] - # indexes the surviving KV_tiles dim in cute.copy. - tAgQ = tAgQ[(None,0)] + # SHAPES (from diag test on B200, n=256): + # tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes + # tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes + # tVgV: (((64,128),1), 1, 1, 1) — 4 modes + # Mode 2 is the GMEM tile iteration axis. + # (None,0,None,0) keeps modes 0 and 2 free → 2D tensor. + # Then [None, kt] indexes the surviving KV_tiles dim. + # The old (None,None,0,0) kept modes 0,1 free → mode 2 (KV tiles) set to 0. + tAgQ = tAgQ[(None,0,None,0)] tBgK = tBgK[(None,0,None,0)] tVgV = tVgV[(None,0,None,0)] @@ -220,7 +221,7 @@ class FmhaV3StageCMulti: # Matches the CUTLASS reference FMHA TMA indexing pattern. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[0], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + cute.copy(tma_q, tAgQ[None, Int32(0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset() for kt in cutlass.range_constexpr(self.n_kv_tiles): diff --git a/tests/unit/test_fmha_v3.py b/tests/unit/test_fmha_v3.py index e65cc2aa..69a7ab33 100644 --- a/tests/unit/test_fmha_v3.py +++ b/tests/unit/test_fmha_v3.py @@ -134,16 +134,7 @@ class FmhaV3: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - print(f"DIAG tAgQ: shape={cute.shape(tAgQ)} stride={tAgQ.layout.stride}") - print(f"DIAG tBgK: shape={cute.shape(tBgK)} stride={tBgK.layout.stride}") - print(f"DIAG tVgV: shape={cute.shape(tVgV)} stride={tVgV.layout.stride}") - print(f"DIAG tAsQ: shape={cute.shape(tAsQ)} stride={tAsQ.layout.stride}") - print(f"DIAG tBsK: shape={cute.shape(tBsK)} stride={tBsK.layout.stride}") - print(f"DIAG tVsV: shape={cute.shape(tVsV)} stride={tVsV.layout.stride}") tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - print(f"DIAG tAgQ sliced (None,0,None,0): shape={cute.shape(tAgQ)} stride={tAgQ.layout.stride}") - print(f"DIAG tBgK sliced (None,0,None,0): shape={cute.shape(tBgK)} stride={tBgK.layout.stride}") - print(f"DIAG tVgV sliced (None,0,None,0): shape={cute.shape(tVgV)} stride={tVgV.layout.stride}") tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV)