Try cutlass.range with Int32(kt) — now n_kv_tiles is Python int

This commit is contained in:
2026-05-22 17:51:25 +00:00
parent bf80fbee99
commit d2bbdd59f6

View File

@@ -205,10 +205,9 @@ class FmhaV3StageC:
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
# Python range() unrolls at trace time. Each iteration emits a
# separate cute.copy with a distinct compile-time Int32 constant.
# We proved Int32(1) hardcoded works — by induction Int32(k) works.
for kt in range(self.n_kv_tiles):
# Use cutlass.range with Python int n_kv_tiles for proper pipeline
# semantics (acquire/release). Wrap kt in Int32() for TMA coordinate.
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
coord = Int32(kt)
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)