diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 300ca8c0..d9362cf3 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -220,10 +220,13 @@ class FmhaKernel: # ===== TMA LOAD warp ===== if warp_idx == self.tma_warp_id: if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd=512): loop over k_sub tiles + # K sub-tiling path (hd>256): Python for loop (unrolled at trace time). + # cutlass.range(unroll=1) creates runtime loops that the MLIR optimizer + # struggles with (45+ min compilation). Python range unrolls at trace time + # but produces simpler IR since each iteration is a flat copy. qp.reset() kvp.reset() - for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + for k_sub in range(self.n_k_sub_tiles): qh = qp.acquire_and_advance() cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) kvh = kvp.acquire_and_advance() @@ -250,13 +253,11 @@ class FmhaKernel: if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd=512): loop over k_sub tiles. - # ACCUMULATE=False for the very first GEMM (k_sub=0, kb=0), - # then True for all subsequent GEMMs. + # K sub-tiling path (hd>256): Python for loop (unrolled at trace time) qc.reset() kvc.reset() qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + for k_sub in range(self.n_k_sub_tiles): qh = qc.wait_and_advance(); qh.release() kvh = kvc.wait_and_advance() for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):