DEBUG: use Int32(kt) directly to test if coordinate matters

This commit is contained in:
2026-05-22 20:34:03 +00:00
parent c01291f16a
commit b7a1deed52

View File

@@ -224,16 +224,18 @@ class FmhaV3StageCMulti:
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset()
# Force kv_coord to be a runtime SSA value.
# The JIT folds n_kv_tiles - n_kv_tiles and Int32(0) to compile-time constants.
# cute.arch.make_warp_uniform(cute.arch.warp_idx()) is a runtime value.
# Multiplying by 0 gives a runtime zero the JIT can't constant-fold.
kv_coord = cute.arch.make_warp_uniform(cute.arch.warp_idx()) * 0
# Try using the pipeline state count (kh.count) as the coordinate.
# This is what the CUTLASS reference's "mode 1" does — the pipeline
# index IS the GMEM tile index for 2-stage pipelines with 2 KV tiles.
# For more tiles, we need a separate counter.
# But first, test whether the coordinate matters at all by passing the
# loop index directly (Int32(kt)), which should give Int32(1) for the
# second tile when n=256.
for kt in range(n_kv_tiles):
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord = kv_coord + 1
# Use a hardcoded coord to test if TMA even reads different tiles
coord = Int32(kt) # Should be 0, 1, 2, ... but might be constant-folded
cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kvp.tail()
# ===== MMA warp =====