Try kt (cutlass.range induction) with correct (None,0,None,0) pre-slice
This commit is contained in:
@@ -214,17 +214,10 @@ class FmhaV3StageCMulti:
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
# CUTLASS reference initializes kv_coord from a runtime parameter
|
||||
# (seqlen_kv_loop_start), not a Python literal. A Python 0 gets
|
||||
# constant-folded by the JIT, so kv_coord += 1 never propagates.
|
||||
# n_kv_tiles - n_kv_tiles is a runtime expression that evaluates to 0
|
||||
# but the JIT can't fold it, forcing kv_coord to be a tracked SSA reg.
|
||||
kv_coord = n_kv_tiles - n_kv_tiles
|
||||
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord = kv_coord + 1
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user