From f2b1cc39d30b1a7dedf0c00ebc560c4e6b2271d0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 20:35:42 +0000 Subject: [PATCH] SMEM counter: separate allocate_tensor instead of struct field --- tests/unit/test_fmha_v3_stage_c.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 7a49867b..e1a7b243 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -148,8 +148,9 @@ class FmhaV3StageCMulti: s_bar: cute.struct.MemRange[cutlass.Int64, 2] acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 - kv_coord_storage: cute.struct.MemRange[cutlass.Int32, 1] smem = utils.SmemAllocator(); st = smem.allocate(SS) + # Separate SMEM allocation for the TMA GMEM coordinate counter + kv_coord_smem = smem.allocate_tensor(element_type=cutlass.Int32, layout=cute.make_layout(1), byte_alignment=4) qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() @@ -225,14 +226,15 @@ class FmhaV3StageCMulti: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset() - # Use SMEM-backed counter — the JIT can't constant-fold SMEM reads. - st.kv_coord_storage[0] = cutlass.Int32(0) + # SMEM-backed counter — the JIT can't constant-fold SMEM reads. + # Initialize to 0 before the loop, read/write each iteration. + kv_coord_smem[0] = cutlass.Int32(0) for kt in range(n_kv_tiles): - kv_coord = st.kv_coord_storage[0] # Dynamic read from SMEM + kv_coord = kv_coord_smem[0] # Dynamic read from SMEM kvh = kvp.acquire_and_advance() cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - st.kv_coord_storage[0] = kv_coord + 1 # Write back to SMEM + kv_coord_smem[0] = kv_coord + 1 # Write back to SMEM kvp.tail() # ===== MMA warp =====