diff --git a/tests/fmha_v3_stage_c_example10.py b/tests/fmha_v3_stage_c_example10.py index f0614abb..2821b5ef 100644 --- a/tests/fmha_v3_stage_c_example10.py +++ b/tests/fmha_v3_stage_c_example10.py @@ -14,18 +14,14 @@ Three structural rules learned the hard way: `utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what the CUTLASS Blackwell FMHA reference uses in `correction_rescale`. -(C) tma_partition produces a tensor with 8 TMA coordinate dimensions, but only - 4 are visible in the Python shape. After - tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, - group_modes(sK,0,3), - group_modes(tCgK,0,3)) - `tBgK` has 8 TMA coord modes: (1,1,1,1,n_kv_tiles,1,1,1). - Mode 4 is the GMEM-tile iteration axis. - Pre-slicing with `tBgK[(None,None,0,0)]` collapses the GMEM tile axis to - coordinate 0, so every TMA copy reads tile 0 regardless of the coord - value passed. The 8-None no-op pre-slice opens the full TMA coord space. - Fix: tBgK = tBgK[(None,None,None,None,None,None,None,None)], then - cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...) +(C) tma_partition produces a 4-mode tensor: (V_grouped, ?, KV_tiles, ?). + Mode 2 is the GMEM-tile iteration axis. Pre-slicing with + `tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets mode 2 to 0, + collapsing the KV-tile axis so TMA always reads tile 0. + Fix: pre-slice with `tBgK[(None,0,None,0)]` to keep modes 0,2 free, + then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes mode 1 (the + surviving KV_tiles mode from the original mode 2) with kt. + This matches the CUTLASS reference FMHA pattern. Kernel structure: @@ -185,15 +181,17 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # NOTE: after tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions. - # Shape is (1,1,1,1,n_kv_tiles,1,1,1) in TMA coord space. - # Mode 4 is the GMEM tile-iteration axis. - # The no-op 8-None slice opens up the full TMA coordinate space. - # Without it, only 4 modes are visible and 8-mode indexing fails. - # The old (None,None,0,0) pre-slice collapsed the GMEM tile axis to 0. + # NOTE: after tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?). + # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles). + # The old pre-slice (None,None,0,0) kept modes 0,1 free and set mode 2 to 0, + # collapsing the KV tile axis. Fix: keep modes 0,2 free with (None,0,None,0), + # then index with [None, kt] in cute.copy. This matches the CUTLASS reference + # pattern: 2D pre-slice, 2-mode indexing in copy. + # tVgV similarly has KV tiles at mode 2. + # tAgQ: (None,0,None,0) works because Q has only 1 tile. tAgQ = tAgQ[(None,0,None,0)] - tBgK = tBgK[(None,None,None,None,None,None,None,None)] - tVgV = tVgV[(None,None,None,None,None,None,None,None)] + tBgK = tBgK[(None,0,None,0)] + tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -217,13 +215,10 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp — fully unrolled ===== - # The old pre-slice tBgK[(None,None,0,0)] collapsed the 8-mode TMA - # coordinate space to 4 modes, pinning mode 4 (GMEM tile dim) to 0. - # The 8-None no-op pre-slice opens the full TMA coord space so we - # can index mode 4 with kt in cute.copy. The pipeline's - # acquire/release machinery still tracks the kv_stage ring buffer - # dynamically at runtime, so the producer correctly blocks on - # consumer release when n_kv_tiles > kv_stage. + # The old pre-slice (None,None,0,0) kept modes 0,1 free and set + # mode 2 (KV tiles) to 0. Fix: (None,0,None,0) keeps modes 0,2 free. + # Then [None, kt] indexes the KV tile axis correctly. + # Matches the CUTLASS reference FMHA TMA indexing pattern. if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) @@ -231,18 +226,17 @@ class FmhaV3StageCMulti: kvp.reset() for kt in cutlass.range_constexpr(self.n_kv_tiles): kvh = kvp.acquire_and_advance() - # 8-mode TMA indexing: mode 4 = GMEM tile axis (size n_kv_tiles). - # The 8-None pre-slice above opened the full coord space. - # kt at mode 4 indexes the correct KV tile in GMEM. + # After (None,0,None,0) pre-slice, tBgK is 2D: (V_grouped, KV_tiles). + # [None, kt] indexes mode 1 (KV_tiles) with kt. cute.copy( tma_k, - tBgK[None, None, None, None, kt, None, None, None], + tBgK[None, kt], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier, ) cute.copy( tma_v, - tVgV[None, None, None, None, kt, None, None, None], + tVgV[None, kt], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier, ) diff --git a/tests/unit/test_fmha_v3_diag.py b/tests/unit/test_fmha_v3_diag.py index 30336de7..37bba2cb 100644 --- a/tests/unit/test_fmha_v3_diag.py +++ b/tests/unit/test_fmha_v3_diag.py @@ -140,13 +140,13 @@ class FmhaV3Diag: tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) # ===================================================================== # ⚠️⚠️⚠️ CRITICAL: TMA PARTITION TENSOR MODE ORDERING ⚠️⚠️⚠️ - # After tma_partition, tBgK/tVgV have 8 modes: (1,1,1,1,n_kv_tiles,1,1,1) + # After tma_partition, tBgK/tVgV have 4 modes: (V_grouped, ?, KV_tiles, ?) # Mode 4 is the GMEM tile dimension. DO NOT pre-slice to fewer modes! # See README.md for details. # ===================================================================== tAgQ = tAgQ[(None,0,None,0)] - tBgK = tBgK[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice! - tVgV = tVgV[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice! + tBgK = tBgK[(None,0,None,0)] + tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -179,10 +179,10 @@ class FmhaV3Diag: kv_coord = Int32(0 + 0) for kt in cutlass.range(self.n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) - # ⚠️ CRITICAL: kv_coord indexes MODE 4 of 8-mode tBgK/tVgV. + # kv_coord indexes the surviving mode 1 (from original mode 2) of 2D tBgK/tVgV. # Using (None, kv_coord) on a pre-sliced 4-mode tensor SILENTLY BREAKS multi-tile! - cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[None, kv_coord], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[None, kv_coord], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kv_coord += 1 pk = cutlass.Boolean(1) kvp.tail() diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 224d6b65..aaffca42 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -179,13 +179,13 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # After tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions. + # After tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?). # Mode 4 is the GMEM tile iteration axis (size = n_kv_tiles). # 8-None no-op pre-slice opens the full TMA coord space. # tAgQ is fine with 4-mode slice (Q has only 1 tile). tAgQ = tAgQ[(None,0,None,0)] - tBgK = tBgK[(None,None,None,None,None,None,None,None)] - tVgV = tVgV[(None,None,None,None,None,None,None,None)] + tBgK = tBgK[(None,0,None,0)] + tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -223,8 +223,8 @@ class FmhaV3StageCMulti: kvp.reset(); pk = kvp.try_acquire() for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[None, kt], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[None, kt], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()