FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing

The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
This commit is contained in:
2026-05-22 15:52:56 +00:00
parent 1ad243f095
commit 07817ae82e

View File

@@ -177,7 +177,7 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,None,None,0)]
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -214,11 +214,11 @@ class FmhaV3StageCMulti:
for kt in cutlass.range(n_kv_tiles, unroll=1):
kh = kvp.acquire_and_advance(pk)
# GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer).
cute.copy(tma_k, tBgK[(None, kt, None)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
cute.copy(tma_k, tBgK[(None, kt, None, 0)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
# GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer).
cute.copy(tma_v, tVgV[(None, kt, None)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
cute.copy(tma_v, tVgV[(None, kt, None, 0)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()