DEBUG: hardcoded Int32(1) to test if TMA can read tile 1
This commit is contained in:
@@ -224,18 +224,11 @@ class FmhaV3StageCMulti:
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset()
|
||||
# Try using the pipeline state count (kvh.count) as the coordinate.
|
||||
# This is what the CUTLASS reference's "mode 1" does — the pipeline
|
||||
# index IS the GMEM tile index for 2-stage pipelines with 2 KV tiles.
|
||||
# For more tiles, we need a separate counter.
|
||||
# But first, let's test if the coordinate matters at all by using
|
||||
# Int32(1) for the second tile when n=256.
|
||||
# DEBUG: Use constant Int32(1) to test if TMA can read from tile 1 at all
|
||||
for kt in range(n_kv_tiles):
|
||||
kvh = kvp.acquire_and_advance()
|
||||
# Use the loop index as the coord to test if TMA even reads different tiles
|
||||
coord = Int32(kt) # Should be 0, 1, 2, ... but might be constant-folded
|
||||
cute.copy(tma_k, tBgK[(None, coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[(None, Int32(1))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, Int32(1))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
|
||||
Reference in New Issue
Block a user