From 7a4ff959bf6d7f3a263def7eefc7290980f0d704 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 14:22:45 +0000 Subject: [PATCH] D1.4: Use cutlass.range loop for k_sub (reduce IR), guard O rescale with const_expr(n_kv_tiles>1) --- dsv4/kernels/attention/fmha.py | 140 ++++++++++++++++----------------- 1 file changed, 68 insertions(+), 72 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index b2c002e7..20d02f73 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -220,19 +220,14 @@ class FmhaKernel: # ===== TMA LOAD warp ===== if warp_idx == self.tma_warp_id: if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd=512): unrolled k_sub loads + # K sub-tiling path (hd=512): loop over k_sub tiles qp.reset() kvp.reset() - # k_sub=0: Load Q[0] and K[0] - qh0 = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh0.index)], tma_bar_ptr=qh0.barrier) - kvh0 = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, Int32(0))], tBsK[(None, kvh0.index)], tma_bar_ptr=kvh0.barrier) - # k_sub=1: Load Q[1] and K[1] - qh1 = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(1))], tAsQ[(None, qh1.index)], tma_bar_ptr=qh1.barrier) - kvh1 = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, Int32(1))], tBsK[(None, kvh1.index)], tma_bar_ptr=kvh1.barrier) + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + kvh = kvp.acquire_and_advance() + cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) # Load V[0] kvh_v = kvp.acquire_and_advance() cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier) @@ -255,24 +250,20 @@ class FmhaKernel: if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd=512): unrolled k_sub iterations - # k_sub=0: QK GEMM with ACCUMULATE=False - qh0 = qc.wait_and_advance(); qh0.release() - kvh0 = kvc.wait_and_advance() + # K sub-tiling path (hd=512): loop over k_sub tiles. + # ACCUMULATE=False for the very first GEMM (k_sub=0, kb=0), + # then True for all subsequent GEMMs. + qc.reset() + kvc.reset() qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh0.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - kvh0.release() - # k_sub=1: QK GEMM with ACCUMULATE=True - qh1 = qc.wait_and_advance(); qh1.release() - kvh1 = kvc.wait_and_advance() - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh1.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - kvh1.release() - # After both k_sub: S has full QK for this kt + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qc.wait_and_advance(); qh.release() + kvh = kvc.wait_and_advance() + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + kvh.release() + # After all k_sub: S has full QK for this kt cute.arch.fence_view_async_tmem_store() softmax_done_bar.arrive() softmax_done_bar.arrive_and_wait() @@ -373,34 +364,38 @@ class FmhaKernel: scale_log2 = Float32(self.scale_softmax_log2) # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) + # Only needed when there are multiple KV tiles (O must be rescaled per-kt). + # With n_kv_tiles=1, no rescale is needed (kt is always 0). + # Define placeholder values unconditionally for CuTeDSL scoping. corr_tile_size = 16 - tOcO = pv_thr.partition_C(cS) - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - tmem_load_o_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_o_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) n_corr_tiles = self.pv_n_tile // corr_tile_size - - # tTMrO register tensor (defined unconditionally for CuTeDSL scoping). - # Used for O rescale (kt > 0) and O normalization (after loop). tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + (cute.make_layout((1,)), 1), self.acc_dtype ) + if const_expr(self.n_kv_tiles > 1): + tOcO = pv_thr.partition_C(cS) + tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) + tmem_store_o_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance() @@ -456,26 +451,27 @@ class FmhaKernel: k2 = k_coord // 64 _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") - if kt > 0: - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() + if const_expr(self.n_kv_tiles > 1): + if kt > 0: + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[k] = tTMrO_i[k] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() si_handle.release() softmax_done_bar.arrive()