D1: Use cutlass.range for k_sub loops (CuTeDSL immutable handle)
This commit is contained in:
@@ -216,8 +216,8 @@ class FmhaKernel:
 # K sub-tiling path (hd=512): load Q and K per k_sub
 qp.reset()
 kvp.reset()
-for kt in range(self.n_kv_tiles):
-    for k_sub in range(self.n_k_sub_tiles):
+for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
+    for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
         # Load Q[k_sub] → sQ
         qh = qp.acquire_and_advance()
         cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
@@ -248,8 +248,8 @@ class FmhaKernel:
 if const_expr(self.n_k_sub_tiles > 1):
     # K sub-tiling path (hd=512): pipeline sync for k_sub iterations
     kvh = kvc.wait_and_advance() # initialize kvh before loops
-    for kt in range(self.n_kv_tiles):
-        for k_sub in range(self.n_k_sub_tiles):
+    for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
+        for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
             qh = qc.wait_and_advance(); qh.release()
             kvh = kvc.wait_and_advance()
             qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)
|
||||
Reference in New Issue
Block a user