From 52857aee16c2f82d9b19c263af4aef0187e216c2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 20:10:35 +0000 Subject: [PATCH] Revert TMA to kt pattern (n=128 works), multi-tile TMA is separate bug --- tests/fmha_v3_stage_c_example7.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/fmha_v3_stage_c_example7.py b/tests/fmha_v3_stage_c_example7.py index 7b7f931c..f322bd20 100644 --- a/tests/fmha_v3_stage_c_example7.py +++ b/tests/fmha_v3_stage_c_example7.py @@ -204,20 +204,18 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # Use the EXACT pattern from test_fmha_v3_diag.py which works for n=256. - # kv_coord = Int32(0+0) + kv_coord += 1 in cutlass.range(self.n_kv_tiles, unroll=1) - # with tBgK = tBgK[(None,None,0,0)] (GMEM tile dim free) + # NOTE: using kt from cutlass.range works for n=128 (single tile). + # Multi-tile (n>128) loads from tile 0 only — the JIT constant-folds kt. + # TODO: fix multi-tile TMA indexing (kv_coord pattern from diag test). if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() - kv_coord = Int32(0 + 0) - for kt in cutlass.range(self.n_kv_tiles, unroll=1): + for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - kv_coord += 1 + cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) pk = cutlass.Boolean(1) kvp.tail()