From fce9a7f4be728291752d7dffd92075aa0a85dac5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:42:20 +0000 Subject: [PATCH 1/2] feat: SMEM-P using make_tiled_copy_tv + logical sP view --- dsv4/kernels/attention/fmha.py | 41 ++++++++++++++++------------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index d8c667c6..e4f13846 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,37 +366,34 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using make_cotiled_copy. - # Build atom_layout_tv: (tid, vid) -> sP_addr. - # Each softmax thread (128 total) owns one row of the 128×128 P matrix. - # Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2) - # matching sP's subtiled layout strides (1, 16, 8192). + # SMEM-P: write P to sP using make_tiled_copy_tv. + # Strategy: Create a R→S copy whose TV layout maps softmax + # thread/value ownership to sP's address space. + # + # Key assumption: 128 softmax threads each own one row of the + # 128×128 P matrix. Within each row, values are in sP's + # subtiled format: (16, 4, 2) with strides (1, 16, 8192). # Thread stride = 64 (one row in sP's layout). # - # atom_layout_tv(tid, k0, k1, k2) = 64*tid + k0 + 16*k1 + 8192*k2 - # sP_addr(m, k0, k1, k2) = 64*m + k0 + 16*k1 + 8192*k2 - # These are the SAME function — tid maps to m naturally. + # Using make_tiled_copy_tv with thread layout and value layout: + # thr_layout: (128,) with stride (1,) — 128 threads + # val_layout: (16, 4, 2) with stride (1, 16, 8192) — 128 values _smem_p_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.q_dtype, num_bits_per_copy=16, ) - _smem_p_atom_layout_tv = cute.make_layout( - (128, (16, 4, 2)), - stride=(64, (1, 16, 8192)), - ) - # Build sP data layout (coalesced, no swizzle, no stage dim) - # sP_outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) - # Coalesced: ((128,16,4,2),(64,1,16,8192)) - _smem_p_data_layout = cute.make_layout( - (128, 16, 4, 2), - stride=(64, 1, 16, 8192), - ) - _tiled_smem_p = cute.make_cotiled_copy( - _smem_p_atom, _smem_p_atom_layout_tv, _smem_p_data_layout, + _smem_p_thr_layout = cute.make_layout((128,), stride=(1,)) + _smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192)) + _tiled_smem_p = cute.make_tiled_copy_tv( + _smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout, ) _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) - _tRS_sP = _thr_smem_p.partition_D(_sP_nostage) + # Create a logical (non-swizzled) view of sP for partitioning + _sP_logical = cute.make_tensor(_sP_nostage.iterator, cute.make_layout( + (128, 16, 4, 2), stride=(64, 1, 16, 8192) + )) + _tRS_sP = _thr_smem_p.partition_D(_sP_logical) _tRS_rP = _tiled_smem_p.retile(rP_bf16) cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence("shared", space="cta") From b67668d2bdb7897c9d0161832bfee557dd3d376f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:43:12 +0000 Subject: [PATCH 2/2] feat: SMEM-P with make_tiled_copy_tv + manual fill --- dsv4/kernels/attention/fmha.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index e4f13846..188179e7 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -375,9 +375,8 @@ class FmhaKernel: # subtiled format: (16, 4, 2) with strides (1, 16, 8192). # Thread stride = 64 (one row in sP's layout). # - # Using make_tiled_copy_tv with thread layout and value layout: - # thr_layout: (128,) with stride (1,) — 128 threads - # val_layout: (16, 4, 2) with stride (1, 16, 8192) — 128 values + # We create the source register tensor in the copy's value order + # and fill it from rP_bf16 during the softmax loop. _smem_p_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.q_dtype, @@ -394,7 +393,24 @@ class FmhaKernel: (128, 16, 4, 2), stride=(64, 1, 16, 8192) )) _tRS_sP = _thr_smem_p.partition_D(_sP_logical) - _tRS_rP = _tiled_smem_p.retile(rP_bf16) + # Create source register tensor in copy's value order + _tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) + # Fill _tRS_rP from rP_bf16 using coordinate-based mapping + # Each thread owns row sfw_idx of the P matrix. + # rP_bf16 is in TMEM-load register order: ((32,1),4,1,1) + # _tRS_rP is in copy value order: (16, 4, 2) per thread + # For thread i, rP_bf16[(j0,0),j1,0,0] corresponds to P[i, j0+32*j1] + # And _tRS_rP[k0,k1,k2] corresponds to P[i, k0+16*k1+64*k2] + # So k0+16*k1+64*k2 = j0+32*j1, which means: + # j0 = (k0+16*k1+64*k2) % 32 + # j1 = (k0+16*k1+64*k2) // 32 + for k0 in range(16): + for k1 in range(4): + for k2 in range(2): + flat_k = k0 + 16 * k1 + 64 * k2 # k index in [0,128) + j0 = flat_k % 32 + j1 = flat_k // 32 + _tRS_rP[k0, k1, k2] = rP_bf16[(j0, 0), j1, 0, 0] cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence("shared", space="cta") if kt > 0: