diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index dac7347d..78f983ca 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -203,22 +203,23 @@ class FmhaV3StageCMulti: pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ===== TMA LOAD warp ===== - # GMEM tile coordinate: use the cutlass.range induction variable kt - # directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=` - # rebinding as a loop-carried iter_args update — the JIT traces the - # body once and captures whatever value `kv_coord` had at trace time, - # so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the - # loop bakes 0 into every iteration's TMA descriptor at runtime. - # The induction variable IS the loop-carried state, properly tracked. + # Following CUTLASS Blackwell FMHA reference pattern exactly: + # Pre-slice tBgK/tVgV (same as reference's tKgK = tKgK_kdl[None,None,0,coord]), + # then index with a loop-carried kv_coord variable inside cutlass.range. + # The reference uses kv_coord += 1 and it works — the key is using a + # Python variable with kv_coord = kv_coord + 1, NOT the cutlass.range + # induction variable directly (which gets constant-folded by JIT). if warp_idx == self.tma_warp_id: qp.reset(); qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() + kv_coord = 0 for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + kv_coord = kv_coord + 1 pk = cutlass.Boolean(1) kvp.tail()