From 0dd6fefd66b645334ae1c86c1e92cdb84b8c32db Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 18:25:13 +0000 Subject: [PATCH] FIX: Force SSA GMEM coord via n_kv_tiles - n_kv_tiles instead of cutlass.range kt --- tests/fmha_v3_stage_c_example6.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/fmha_v3_stage_c_example6.py b/tests/fmha_v3_stage_c_example6.py index 912e8fd5..fd7e05ec 100644 --- a/tests/fmha_v3_stage_c_example6.py +++ b/tests/fmha_v3_stage_c_example6.py @@ -207,22 +207,25 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # GMEM tile coordinate: use the cutlass.range induction variable kt - # directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=` - # rebinding as a loop-carried iter_args update — the JIT traces the - # body once and captures whatever value `kv_coord` had at trace time, - # so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the - # loop bakes 0 into every iteration's TMA descriptor at runtime. - # The induction variable IS the loop-carried state, properly tracked. + # GMEM tile coordinate: use a manual kv_coord counter. + # The induction variable from cutlass.range gets constant-folded by + # CuTeDSL's JIT — kt in tBgK[(None, kt)] bakes to 0 at runtime. + # Force the initial value to be an SSA value (not a constant literal) + # by computing zero as n_kv_tiles - n_kv_tiles. The JIT can't fold + # this because n_kv_tiles comes from cute.size() which returns an + # SSA Int32. Then kv_coord = kv_coord + Int32(1) is properly tracked as iter_args + # on the underlying scf.for. 
if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() + kv_coord = n_kv_tiles - n_kv_tiles # Force SSA zero (not constant-folded) for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + kv_coord = kv_coord + Int32(1) pk = cutlass.Boolean(1) kvp.tail()