Diag: try runtime Int32(0+0) for kv_coord with cutlass.range

This commit is contained in:
2026-05-22 17:57:58 +00:00
parent e5030cbea5
commit 601e662dd4

View File

@@ -167,11 +167,13 @@ class FmhaV3Diag:
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in range(self.n_kv_tiles):
coord = Int32(kt)
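# Diagnostic: use an Int32 coordinate that is incremented inside the
# cutlass.range loop instead of a per-iteration literal (option 3 suggested
# by the CUTLASS LLM).
# NOTE(review): Python evaluates 0 + 0 to the plain int 0 before Int32() is
# called, so Int32(0 + 0) is the same as Int32(0) — confirm this actually
# produces a runtime (non-constexpr) value.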
kv_coord = Int32(0 + 0)
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
pk = cutlass.Boolean(1)
kvp.tail()