Fix SMEM counter type: cutlass.Int32 for MemRange
This commit is contained in:
@@ -148,7 +148,7 @@ class FmhaV3StageCMulti:
|
||||
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
|
||||
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
kv_coord_smem: cute.struct.MemRange[Int32, 1]
|
||||
kv_coord_storage: cute.struct.MemRange[cutlass.Int32, 1]
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
@@ -226,14 +226,13 @@ class FmhaV3StageCMulti:
|
||||
qp.tail()
|
||||
kvp.reset()
|
||||
# Use SMEM-backed counter — the JIT can't constant-fold SMEM reads.
|
||||
# Initialize to 0, read before each copy, increment after.
|
||||
st.kv_coord_smem[0] = Int32(0)
|
||||
st.kv_coord_storage[0] = cutlass.Int32(0)
|
||||
for kt in range(n_kv_tiles):
|
||||
kv_coord = st.kv_coord_smem[0] # Dynamic read from SMEM
|
||||
kv_coord = st.kv_coord_storage[0] # Dynamic read from SMEM
|
||||
kvh = kvp.acquire_and_advance()
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
st.kv_coord_smem[0] = kv_coord + 1 # Write back to SMEM
|
||||
st.kv_coord_storage[0] = kv_coord + 1 # Write back to SMEM
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
|
||||
Reference in New Issue
Block a user