Stage C: manual kv_coord + correct K GMEM slice + O rescale fence

Key fixes:
1. GMEM tile coord: manual Int32 kv_coord (not kvh.count)
2. K GMEM slice: (None,None,0,0) keeps mode 1 free (GMEM iter)
3. V GMEM slice: (None,0,None,0) keeps mode 2 free (GMEM iter)
4. Add fence_view_async_tmem_load before O rescale for visibility
This commit is contained in:
2026-05-22 17:26:56 +00:00
parent 5a6c4d2cd2
commit ba2cefb668

View File

@@ -167,13 +167,13 @@ class FmhaV3StageC:
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
# Only slice GMEM tensors (SMEM tensors from tma_partition are already 2D).
# K from QK MMA: GMEM iter at mode 1 → slice [(None,None,0,0)] keeps modes 0,1
# V from PV MMA: GMEM iter at mode 2 → slice [(None,0,None,0)] keeps modes 0,2
# Q: 1 tile only, original slice is fine.
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode GMEM iter to 0
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free for kvh.count
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free for kvh.count
# GMEM slices: keep the GMEM iteration mode free for kv_coord indexing.
# CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter)
# tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2)
# SMEM tensors from tma_partition are already 2D — don't re-slice them.
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
@@ -198,20 +198,19 @@ class FmhaV3StageC:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
# One acquire per kt; K and V both target kvh.barrier. kvh.count == kt.
# GMEM tile coordinate: manual Int32 counter (kv_coord). SMEM slot: kvh.index.
# Pipeline handle .count is NOT a usable GMEM coordinate.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
kv_coord = Int32(0)
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# Both transfers decrement the same barrier's tx_count.
# kvh.count is a pipeline-state Int32 (the form cute.copy accepts).
# K: mode 1 is GMEM iter → tBgK[(None, kvh.count)]
# V: mode 2 is GMEM iter → tVgV[(None, kvh.count)]
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
pk = cutlass.Boolean(1)
kvp.tail()
@@ -351,9 +350,10 @@ class FmhaV3StageC:
cute.arch.fence_view_async_tmem_store()
# O rescale for kt > 0. Reads O written by MMA's PV[kt-1];
# visibility is provided by s_cons.wait_and_advance above
# (acquires on MMA's S[kt] commit, which orders PV[kt-1] before).
# s_cons.wait_and_advance acquires on MMA's S[kt] commit (sequenced
# after PV[kt-1]), but add an explicit tmem_load fence for safety.
if kt > 0:
cute.arch.fence_view_async_tmem_load()
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,