diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 29fdf6c7..794580d9 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -216,8 +216,8 @@ class FmhaKernel: # K sub-tiling path (hd=512): load Q and K per k_sub qp.reset() kvp.reset() - for kt in range(self.n_kv_tiles): - for k_sub in range(self.n_k_sub_tiles): + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): # Load Q[k_sub] → sQ qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) @@ -248,8 +248,8 @@ class FmhaKernel: if const_expr(self.n_k_sub_tiles > 1): # K sub-tiling path (hd=512): pipeline sync for k_sub iterations kvh = kvc.wait_and_advance() # initialize kvh before loops - for kt in range(self.n_kv_tiles): - for k_sub in range(self.n_k_sub_tiles): + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): qh = qc.wait_and_advance(); qh.release() kvh = kvc.wait_and_advance() qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)