Fix multi-tile TMA: loop-carried kv_coord (CUTLASS reference pattern)

This commit is contained in:
2026-05-22 22:25:00 +00:00
parent 2a14c2dd18
commit 16f60e2dd1

View File

@@ -203,22 +203,23 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
# GMEM tile coordinate: use the cutlass.range induction variable kt
# directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=`
# rebinding as a loop-carried iter_args update — the JIT traces the
# body once and captures whatever value `kv_coord` had at trace time,
# so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the
# loop bakes 0 into every iteration's TMA descriptor at runtime.
# The induction variable IS the loop-carried state, properly tracked.
# Following CUTLASS Blackwell FMHA reference pattern exactly:
# Pre-slice tBgK/tVgV (same as reference's tKgK = tKgK_kdl[None,None,0,coord]),
# then index with a loop-carried kv_coord variable inside cutlass.range.
# The reference uses kv_coord += 1 and it works — the key is using a
# Python variable with kv_coord = kv_coord + 1, NOT the cutlass.range
# induction variable directly (which gets constant-folded by JIT).
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
kv_coord = 0
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord = kv_coord + 1
pk = cutlass.Boolean(1)
kvp.tail()