FIX: keep GMEM iteration dimension FREE in TMA K/V partition slices

Root cause of multi-tile failure: the (None,0,None,0) slice applied to
tBgK hardcodes the GMEM tile dimension to 0, so TMA always loads K from
tile 0 regardless of kvh.count. K from the QK MMA has its GMEM iteration
at mode 1, whereas V from the PV MMA has it at mode 2 (different layouts:
K,D,L vs D,K,L), so the same slice is only correct for V.

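The fix follows the CUTLASS FMHA reference (partition slice first, then
the per-tile index used in cute.copy):
- K: tBgK[(None,None,None,0)], then tBgK[(None, kvh.count, None)]
- V: tVgV[(None,0,None,0)], then tVgV[(None, kvh.count)] (unchanged)
- Q: sliced like K: tAgQ[(None,None,None,0)], then tAgQ[(None, qh.count, None)]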
This commit is contained in:
2026-05-22 16:51:57 +00:00
parent 04da36e18c
commit 7aaf9ccbda
2 changed files with 15 additions and 6 deletions

View File

@@ -137,7 +137,7 @@ class FmhaV3Diag:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tAgQ = tAgQ[(None,None,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -163,12 +163,12 @@ class FmhaV3Diag:
# TMA LOAD (combined K+V barrier)
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[(None, qh.count, None)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kvh.count, None)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()

View File

@@ -167,7 +167,14 @@ class FmhaV3StageC:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
# Q: GMEM iter at mode 1 (from QK MMA A-partition, layout Q,D,L)
# K: GMEM iter at mode 1 (from QK MMA B-partition, layout K,D,L)
# V: GMEM iter at mode 2 (from PV MMA B-partition, layout D,K,L)
# Must keep GMEM iter dimension FREE for multi-tile KV loads.
# CUTLASS reference: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0]
tAgQ = tAgQ[(None,None,None,0)] # Q: keep mode 1 free (only 1 tile, doesn't matter)
tBgK = tBgK[(None,None,None,0)] # K: keep mode 1 free for kvh.count
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 free for kvh.count
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -195,14 +202,16 @@ class FmhaV3StageC:
# One acquire per kt; K and V both target kvh.barrier. kvh.count == kt.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[(None, qh.count, None)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# Both transfers decrement the same barrier's tx_count.
# kvh.count is a pipeline-state Int32 (the form cute.copy accepts).
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
# K: GMEM iter at mode 1, so index (None, kvh.count, None)
# V: GMEM iter at mode 2, so index (None, kvh.count)
cute.copy(tma_k, tBgK[(None, kvh.count, None)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()