Revert TMA load to the kt loop-index pattern (works for n=128); multi-tile TMA is a separate bug

This commit is contained in:
2026-05-22 20:10:35 +00:00
parent dc9a5bc499
commit 52857aee16

View File

@@ -204,20 +204,18 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
# Use the EXACT pattern from test_fmha_v3_diag.py which works for n=256.
# kv_coord = Int32(0+0) + kv_coord += 1 in cutlass.range(self.n_kv_tiles, unroll=1)
# with tBgK = tBgK[(None,None,0,0)] (GMEM tile dim free)
# NOTE: indexing tBgK/tVgV with kt from cutlass.range works for n=128 (single tile).
# Multi-tile (n>128) loads from tile 0 only — the JIT constant-folds kt.
# TODO: fix multi-tile TMA indexing (use the kv_coord pattern from test_fmha_v3_diag.py).
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
kv_coord = Int32(0 + 0)
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()