FIX: Use Python range() in TMA warp for concrete per-iteration GMEM coords
This commit is contained in:
@@ -207,25 +207,20 @@ class FmhaV3StageCMulti:
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp =====
|
||||
# GMEM tile coordinate: use a manual kv_coord counter.
|
||||
# The induction variable from cutlass.range gets constant-folded by
|
||||
# CuTeDSL's JIT — kt in tBgK[(None, kt)] bakes to 0 at runtime.
|
||||
# Force the initial value to be an SSA value (not a constant literal)
|
||||
# by computing zero as n_kv_tiles - n_kv_tiles. The JIT can't fold
|
||||
# this because n_kv_tiles comes from cute.size() which returns an
|
||||
# SSA Int32. Then kv_coord += 1 is properly tracked as iter_args
|
||||
# on the underlying scf.for.
|
||||
# GMEM tile coordinate: use Python range() so the JIT traces each
|
||||
# iteration separately with concrete kt values. cutlass.range generates
|
||||
# an scf.for where the induction variable gets constant-folded into
|
||||
# the TMA descriptor (always 0 at runtime). Plain range() unrolls at
|
||||
# trace time, giving each iteration a distinct static coordinate.
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
kv_coord = n_kv_tiles - n_kv_tiles # Force SSA zero (not constant-folded)
|
||||
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
|
||||
for kt in range(n_kv_tiles):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord = kv_coord + Int32(1)
|
||||
cute.copy(tma_k, tBgK[(None, Int32(kt))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user