Stage C: manual kv_coord + correct K GMEM slice + O rescale fence
Key fixes:
1. GMEM tile coord: use a manual Int32 kv_coord (not kvh.count).
2. K GMEM slice: (None,None,0,0) keeps mode 1 (the GMEM iteration mode) free.
3. V GMEM slice: (None,0,None,0) keeps mode 2 (the GMEM iteration mode) free.
4. Add fence_view_async_tmem_load before the O rescale for visibility.
This commit is contained in:
@@ -167,13 +167,13 @@ class FmhaV3StageC:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# Only slice GMEM tensors (SMEM tensors from tma_partition are already 2D).
|
||||
# K from QK MMA: GMEM iter at mode 1 → slice [(None,None,0,0)] keeps modes 0,1
|
||||
# V from PV MMA: GMEM iter at mode 2 → slice [(None,0,None,0)] keeps modes 0,2
|
||||
# Q: 1 tile only, original slice is fine.
|
||||
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode GMEM iter to 0
|
||||
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free for kvh.count
|
||||
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free for kvh.count
|
||||
# GMEM slices: keep the GMEM iteration mode free for kv_coord indexing.
|
||||
# CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter)
|
||||
# tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2)
|
||||
# SMEM tensors from tma_partition are already 2D — don't re-slice them.
|
||||
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine
|
||||
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free
|
||||
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -198,20 +198,19 @@ class FmhaV3StageC:
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp =====
|
||||
# One acquire per kt; K and V both target kvh.barrier. kvh.count == kt.
|
||||
# GMEM tile coordinate: manual Int32 counter (kv_coord). SMEM slot: kvh.index.
|
||||
# Pipeline handle .count is NOT a usable GMEM coordinate.
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
kv_coord = Int32(0)
|
||||
for kt in cutlass.range(n_kv_tiles, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
# Both transfers decrement the same barrier's tx_count.
|
||||
# kvh.count is a pipeline-state Int32 (the form cute.copy accepts).
|
||||
# K: mode 1 is GMEM iter → tBgK[(None, kvh.count)]
|
||||
# V: mode 2 is GMEM iter → tVgV[(None, kvh.count)]
|
||||
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord += 1
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
@@ -351,9 +350,10 @@ class FmhaV3StageC:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# O rescale for kt > 0. Reads O written by MMA's PV[kt-1];
|
||||
# visibility is provided by s_cons.wait_and_advance above
|
||||
# (acquires on MMA's S[kt] commit, which orders PV[kt-1] before).
|
||||
# s_cons.wait_and_advance acquires on MMA's S[kt] commit (sequenced
|
||||
# after PV[kt-1]), but add an explicit tmem_load fence for safety.
|
||||
if kt > 0:
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for i in range(o_col_tiles):
|
||||
tTMEM_LOAD_O_i = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,
|
||||
|
||||
Reference in New Issue
Block a user