diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index dffbabde..b65822ae 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -178,14 +178,15 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # CRITICAL: tBgK/tVgV have 8 modes after tma_partition. - # Mode 4 is the GMEM tile iteration axis. Pre-slicing with - # (None,None,0,0) collapses modes 4-7 to 0 — TMA always reads tile 0. - # Fix: 8-None no-op slice preserves all modes; (None, kt) in copy - # addresses mode 4 correctly. + # CRITICAL: After tma_partition, tBgK has 4 modes: (((64,128),1),?,?,?) + # Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles). + # The old pre-slice (None,None,0,0) set mode 2 to 0, so TMA always + # read tile 0. Fix: don't pre-slice; use 4-mode indexing in cute.copy + # with kt at mode 2. + # tVgV similarly has mode 2 as the GMEM tile dim. + # tAgQ is fine — Q has only 1 tile, no iteration needed. tAgQ = tAgQ[(None,0,None,0)] - tBgK = tBgK[(None,None,None,None,None,None,None,None)] - tVgV = tVgV[(None,None,None,None,None,None,None,None)] + # No pre-slice for tBgK/tVgV — we index all 4 modes in cute.copy tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -221,8 +222,9 @@ class FmhaV3StageCMulti: # correctly addresses mode 4 (GMEM tile dim) in cute.copy. for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + # 4-mode indexing: mode 2 = GMEM tile dim (n_kv_tiles) + cute.copy(tma_k, tBgK[(None, None, kt, 0)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, 0, kt, 0)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kvp.tail() # ===== MMA warp =====