Test: use kvh.index (pipeline state) as TMA GMEM coordinate
This commit is contained in:
@@ -226,15 +226,16 @@ class FmhaV3StageCMulti:
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset()
|
||||
# SMEM-backed counter — the JIT can't constant-fold SMEM reads.
|
||||
# Initialize to 0 before the loop, read/write each iteration.
|
||||
kv_coord_smem[0] = cutlass.Int32(0)
|
||||
# Use pipeline state index as a proxy for the GMEM tile coordinate.
|
||||
# The pipeline state IS properly tracked by the JIT as dynamic.
|
||||
# For 2-stage pipelines with 2 KV tiles, kvh.index cycles 0,1.
|
||||
# This maps directly to the GMEM tile index.
|
||||
# For >2 tiles, we need the SMEM counter to track the iteration.
|
||||
for kt in range(n_kv_tiles):
|
||||
kv_coord = kv_coord_smem[0] # Dynamic read from SMEM
|
||||
kvh = kvp.acquire_and_advance()
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord_smem[0] = kv_coord + 1 # Write back to SMEM
|
||||
# Use kvh.index as the coordinate (0 for first stage, 1 for second)
|
||||
cute.copy(tma_k, tBgK[(None, kvh.index)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kvh.index)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
|
||||
Reference in New Issue
Block a user