D1: Use cutlass.range for k_sub loops (CuTeDSL immutable handle)

This commit is contained in:
2026-05-24 06:43:30 +00:00
parent 2bf3ee40aa
commit dd39c2ebdf

View File

@@ -216,8 +216,8 @@ class FmhaKernel:
# K sub-tiling path (hd=512): load Q and K per k_sub
qp.reset()
kvp.reset()
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
# Load Q[k_sub] → sQ
qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
@@ -248,8 +248,8 @@ class FmhaKernel:
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd=512): pipeline sync for k_sub iterations
kvh = kvc.wait_and_advance() # initialize kvh before loops
for kt in range(self.n_kv_tiles):
for k_sub in range(self.n_k_sub_tiles):
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qc.wait_and_advance(); qh.release()
kvh = kvc.wait_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)