Test: kv_coord = warp_idx() * 0 — force SSA from runtime value
This commit is contained in:
@@ -224,9 +224,12 @@ class FmhaV3StageCMulti:
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset()
|
||||
kv_coord = n_kv_tiles - n_kv_tiles # SSA runtime zero
|
||||
# Force kv_coord to be a runtime SSA value.
|
||||
# The JIT folds n_kv_tiles - n_kv_tiles and Int32(0) to compile-time constants.
|
||||
# cute.arch.make_warp_uniform(cute.arch.warp_idx()) is a runtime value.
|
||||
# Multiplying by 0 gives a runtime zero the JIT can't constant-fold.
|
||||
kv_coord = cute.arch.make_warp_uniform(cute.arch.warp_idx()) * 0
|
||||
for kt in range(n_kv_tiles):
|
||||
cute.printf("TMA kt=%d kv_coord=%d\n", kt, kv_coord)
|
||||
kvh = kvp.acquire_and_advance()
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
|
||||
Reference in New Issue
Block a user