diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index e1a7b243..a4a4f386 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -226,15 +226,16 @@ class FmhaV3StageCMulti: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset() - # SMEM-backed counter — the JIT can't constant-fold SMEM reads. - # Initialize to 0 before the loop, read/write each iteration. - kv_coord_smem[0] = cutlass.Int32(0) + # Use the pipeline stage index (kvh.index) as a proxy for the GMEM tile coordinate. + # The JIT tracks the pipeline state as a dynamic value. + # With a 2-stage pipeline and 2 KV tiles, kvh.index takes the values 0, 1, + # so it maps directly to the GMEM tile index. + # NOTE: for >2 tiles this mapping no longer holds; the SMEM counter is needed to track the iteration. for kt in range(n_kv_tiles): - kv_coord = kv_coord_smem[0] # Dynamic read from SMEM kvh = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - kv_coord_smem[0] = kv_coord + 1 # Write back to SMEM + # Use kvh.index as the coordinate (0 for first stage, 1 for second) + cute.copy(tma_k, tBgK[(None, kvh.index)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kvh.index)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kvp.tail() # ===== MMA warp =====