FIX: consistent GMEM/SMEM slicing for K and V TMA partitions

Both GMEM and SMEM sides must be sliced to the same rank for cute.copy.
K (QK MMA B-partition): slice [(None,None,0,0)] keeps modes 0,1
  - mode 1 = GMEM iteration, indexed by kvh.count
V (PV MMA B-partition): slice [(None,0,None,0)] keeps modes 0,2
  - mode 2 = GMEM iteration, indexed by kvh.count
Q: only 1 tile, (None,0,None,0) hardcode is fine.
This commit is contained in:
2026-05-22 16:56:38 +00:00
parent 5c1423b2c5
commit 15a2fdadbc
2 changed files with 14 additions and 11 deletions

View File

@@ -137,7 +137,9 @@ class FmhaV3Diag:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,None,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,0,None,0)]
tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)]
tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)]
tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -163,12 +165,12 @@ class FmhaV3Diag:
# TMA LOAD (combined K+V barrier)
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count, None)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kvh.count, None)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()

View File

@@ -171,10 +171,11 @@ class FmhaV3StageC:
# K: GMEM iter at mode 1 (from QK MMA B-partition, layout K,D,L)
# V: GMEM iter at mode 2 (from PV MMA B-partition, layout D,K,L)
# Must keep GMEM iter dimension FREE for multi-tile KV loads.
# CUTLASS reference: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0]
tAgQ = tAgQ[(None,None,None,0)] # Q: keep mode 1 free (only 1 tile, doesn't matter)
tBgK = tBgK[(None,None,None,0)] # K: keep mode 1 free for kvh.count
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 free for kvh.count
# CUTLASS reference slices: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0]
# Both GMEM and SMEM sides must be sliced consistently for cute.copy rank matching.
tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile only
tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter)
tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter)
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -202,16 +203,16 @@ class FmhaV3StageC:
# One acquire per kt; K and V both target kvh.barrier. kvh.count == kt.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count, None)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# Both transfers decrement the same barrier's tx_count.
# kvh.count is a pipeline-state Int32 (the form cute.copy accepts).
# K: GMEM iter at mode 1, so index (None, kvh.count, None)
# V: GMEM iter at mode 2, so index (None, kvh.count)
cute.copy(tma_k, tBgK[(None, kvh.count, None)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
# K: mode 1 is GMEM iter → tBgK[(None, kvh.count)]
# V: mode 2 is GMEM iter → tVgV[(None, kvh.count)]
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()