CRITICAL FIX: keep GMEM iteration dim free in tBgK/tVgV slice

The slice (None,0,None,0) was hardcoding the GMEM iteration dim to 0,
meaning TMA always loaded K/V from tile 0 regardless of kt.
Changed to (None,None,None,0) to keep gmem_iter free,
then index with (None, kt, None) in the TMA copy loop.

This is the root cause of multi-tile failure: TMA was always reading
the first 128 tokens for ALL KV tiles.
This commit is contained in:
2026-05-22 15:52:06 +00:00
parent a04b219f0f
commit 3d2cb0e52b

View File

@@ -177,7 +177,7 @@ class FmhaV3StageCMulti:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,None,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -214,11 +214,11 @@ class FmhaV3StageCMulti:
for kt in cutlass.range(n_kv_tiles, unroll=1):
kh = kvp.acquire_and_advance(pk)
# GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer).
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
cute.copy(tma_k, tBgK[(None, kt, None)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
# GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer).
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
cute.copy(tma_v, tVgV[(None, kt, None)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()