diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index aeee8739..b0c3a3f3 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -204,12 +204,14 @@ class FmhaV3StageC: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() - kv_coord = Int32(0) # MUST be Int32 for TMA addressing - for kt in cutlass.range(n_kv_tiles, unroll=1): + # Python range() unrolls at trace time. Each iteration emits a + # separate cute.copy with a distinct compile-time Int32 constant. + # We proved Int32(1) hardcoded works — by induction Int32(k) works. + for kt in range(n_kv_tiles): + coord = Int32(kt) kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - kv_coord += 1 + cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()