From 7357b1a866fb17d39b92ebfe5cc18e33674fb5d0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:44:20 +0000 Subject: [PATCH] fix: fence_proxy not fence --- dsv4/kernels/attention/fmha.py | 61 +++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 59354a64..56092241 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,21 +366,52 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using make_cotiled_copy. - # Build atom_layout_tv: (tid, vid) -> sP_addr. - # Each softmax thread (128 total) owns one row of the 128×128 P matrix. - # Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2) - # matching sP's subtiled layout strides (1, 16, 8192). - # SMEM-P: write P to sP using coordinate-indexed store. - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + # SMEM-P: write P to sP using make_tiled_copy_tv. + # Strategy: Create a R→S copy whose TV layout maps softmax + # thread/value ownership to sP's address space. + # + # Key assumption: 128 softmax threads each own one row of the + # 128×128 P matrix. Within each row, values are in sP's + # subtiled format: (16, 4, 2) with strides (1, 16, 8192). + # Thread stride = 64 (one row in sP's layout). + # + # We create the source register tensor in the copy's value order + # and fill it from rP_bf16 during the softmax loop. + _smem_p_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=16, + ) + _smem_p_thr_layout = cute.make_layout((128,), stride=(1,)) + _smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192)) + _tiled_smem_p = cute.make_tiled_copy_tv( + _smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout, + ) + _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) + # Create a logical (non-swizzled) view of sP for partitioning + _sP_logical = cute.make_tensor(_sP_nostage.iterator, cute.make_layout( + (128, 16, 4, 2), stride=(64, 1, 16, 8192) + )) + _tRS_sP = _thr_smem_p.partition_D(_sP_logical) + # Create source register tensor in copy's value order + _tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) + # Fill _tRS_rP from rP_bf16 using coordinate-based mapping + # Each thread owns row sfw_idx of the P matrix. + # rP_bf16 is in TMEM-load register order: ((32,1),4,1,1) + # _tRS_rP is in copy value order: (16, 4, 2) per thread + # For thread i, rP_bf16[(j0,0),j1,0,0] corresponds to P[i, j0+32*j1] + # And _tRS_rP[k0,k1,k2] corresponds to P[i, k0+16*k1+64*k2] + # So k0+16*k1+64*k2 = j0+32*j1, which means: + # j0 = (k0+16*k1+64*k2) % 32 + # j1 = (k0+16*k1+64*k2) // 32 + for k0 in range(16): + for k1 in range(4): + for k2 in range(2): + flat_k = k0 + 16 * k1 + 64 * k2 # k index in [0,128) + j0 = flat_k % 32 + j1 = flat_k // 32 + _tRS_rP[k0, k1, k2] = rP_bf16[(j0, 0), j1, 0, 0] + cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: for i in range(n_corr_tiles):