From 07817ae82e058f5ed2fcfe321cc1f0551be8b356 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 15:52:56 +0000 Subject: [PATCH] FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0. Instead, keep the original tBgK and index with (None, kt, None, 0) inside the TMA loop, where kt selects the correct GMEM tile. This preserves 2D rank matching with the SMEM tensor. --- tests/fmha_v3_stage_c_example1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/fmha_v3_stage_c_example1.py b/tests/fmha_v3_stage_c_example1.py index f0811995..a4da823e 100644 --- a/tests/fmha_v3_stage_c_example1.py +++ b/tests/fmha_v3_stage_c_example1.py @@ -177,7 +177,7 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,None,None,0)] + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -214,11 +214,11 @@ class FmhaV3StageCMulti: for kt in cutlass.range(n_kv_tiles, unroll=1): kh = kvp.acquire_and_advance(pk) # GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer). - cute.copy(tma_k, tBgK[(None, kt, None)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier) + cute.copy(tma_k, tBgK[(None, kt, None, 0)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier) pk = cutlass.Boolean(1) vh = kvp.acquire_and_advance(pk) # GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer). - cute.copy(tma_v, tVgV[(None, kt, None)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier) + cute.copy(tma_v, tVgV[(None, kt, None, 0)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier) pk = cutlass.Boolean(1) kvp.tail()