Option 2: Python range() with Int32(kt) for TMA GMEM coord

cutlass.range traces once - kv_coord/kt are trace-time values,
not runtime loop-carried state. Python range() fully unrolls at
trace time, emitting distinct Int32(k) constants per iteration.
Int32(1) hardcoded already proved TMA CAN load from tile 1.
This commit is contained in:
2026-05-22 17:47:43 +00:00
parent 06a0f5fc53
commit 82c46d438e

View File

@@ -204,12 +204,14 @@ class FmhaV3StageC:
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
kv_coord = Int32(0) # MUST be Int32 for TMA addressing
for kt in cutlass.range(n_kv_tiles, unroll=1):
# Python range() unrolls at trace time. Each iteration emits a
# separate cute.copy with a distinct compile-time Int32 constant.
# We proved Int32(1) hardcoded works — by induction Int32(k) works.
for kt in range(n_kv_tiles):
coord = Int32(kt)
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()