From b7a1deed529b9606dbce8326a4cd5ec04d55c26e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 20:34:03 +0000 Subject: [PATCH] DEBUG: use Int32(kt) directly to test if coordinate matters --- tests/unit/test_fmha_v3_stage_c.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 1403ce2c..5a1c63e3 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -224,16 +224,18 @@ class FmhaV3StageCMulti: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset() - # Force kv_coord to be a runtime SSA value. - # The JIT folds n_kv_tiles - n_kv_tiles and Int32(0) to compile-time constants. - # cute.arch.make_warp_uniform(cute.arch.warp_idx()) is a runtime value. - # Multiplying by 0 gives a runtime zero the JIT can't constant-fold. - kv_coord = cute.arch.make_warp_uniform(cute.arch.warp_idx()) * 0 + # Try using the pipeline state count (kh.count) as the coordinate. + # This is what the CUTLASS reference's "mode 1" does — the pipeline + # index IS the GMEM tile index for 2-stage pipelines with 2 KV tiles. + # For more tiles, we need a separate counter. + # But first, let's test if the coordinate matters at all by using + # Int32(1) for the second tile when n=256. for kt in range(n_kv_tiles): kvh = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - kv_coord = kv_coord + 1 + # Use a hardcoded coord to test if TMA even reads different tiles + coord = Int32(kt) # Should be 0, 1, 2, ... but might be constant-folded + cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kvp.tail() # ===== MMA warp =====