From b3778896b9f174545ed0022d9c3c19fc6fde767b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 20:33:40 +0000 Subject: [PATCH] =?UTF-8?q?Test:=20kv=5Fcoord=20=3D=20warp=5Fidx()=20*=200?= =?UTF-8?q?=20=E2=80=94=20force=20SSA=20from=20runtime=20value?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_fmha_v3_stage_c.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 6253651f..1403ce2c 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -224,9 +224,12 @@ class FmhaV3StageCMulti: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset() - kv_coord = n_kv_tiles - n_kv_tiles # SSA runtime zero + # Force kv_coord to be a runtime SSA value. + # The JIT folds n_kv_tiles - n_kv_tiles and Int32(0) to compile-time constants. + # cute.arch.make_warp_uniform(cute.arch.warp_idx()) is a runtime value. + # Multiplying by 0 gives a runtime zero the JIT can't constant-fold. + kv_coord = cute.arch.make_warp_uniform(cute.arch.warp_idx()) * 0 for kt in range(n_kv_tiles): - cute.printf("TMA kt=%d kv_coord=%d\n", kt, kv_coord) kvh = kvp.acquire_and_advance() cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)