From eb9f4e553fb43019d9cfc0f2ab0f8f68829296dd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:42:20 +0000 Subject: [PATCH] feat: SMEM-P using make_tiled_copy_tv + logical sP view --- dsv4/kernels/attention/fmha.py | 41 ++++++++++++++++------------------ 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index d8c667c6..e4f13846 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,37 +366,34 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using make_cotiled_copy. - # Build atom_layout_tv: (tid, vid) -> sP_addr. - # Each softmax thread (128 total) owns one row of the 128×128 P matrix. - # Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2) - # matching sP's subtiled layout strides (1, 16, 8192). + # SMEM-P: write P to sP using make_tiled_copy_tv. + # Strategy: Create a R→S copy whose TV layout maps softmax + # thread/value ownership to sP's address space. + # + # Key assumption: 128 softmax threads each own one row of the + # 128×128 P matrix. Within each row, values are in sP's + # subtiled format: (16, 4, 2) with strides (1, 16, 8192). # Thread stride = 64 (one row in sP's layout). # - # atom_layout_tv(tid, k0, k1, k2) = 64*tid + k0 + 16*k1 + 8192*k2 - # sP_addr(m, k0, k1, k2) = 64*m + k0 + 16*k1 + 8192*k2 - # These are the SAME function — tid maps to m naturally. + # Using make_tiled_copy_tv with thread layout and value layout: + # thr_layout: (128,) with stride (1,) — 128 threads + # val_layout: (16, 4, 2) with stride (1, 16, 8192) — 128 values _smem_p_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.q_dtype, num_bits_per_copy=16, ) - _smem_p_atom_layout_tv = cute.make_layout( - (128, (16, 4, 2)), - stride=(64, (1, 16, 8192)), - ) - # Build sP data layout (coalesced, no swizzle, no stage dim) - # sP_outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) - # Coalesced: ((128,16,4,2),(64,1,16,8192)) - _smem_p_data_layout = cute.make_layout( - (128, 16, 4, 2), - stride=(64, 1, 16, 8192), - ) - _tiled_smem_p = cute.make_cotiled_copy( - _smem_p_atom, _smem_p_atom_layout_tv, _smem_p_data_layout, + _smem_p_thr_layout = cute.make_layout((128,), stride=(1,)) + _smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192)) + _tiled_smem_p = cute.make_tiled_copy_tv( + _smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout, ) _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) - _tRS_sP = _thr_smem_p.partition_D(_sP_nostage) + # Create a logical (non-swizzled) view of sP for partitioning + _sP_logical = cute.make_tensor(_sP_nostage.iterator, cute.make_layout( + (128, 16, 4, 2), stride=(64, 1, 16, 8192) + )) + _tRS_sP = _thr_smem_p.partition_D(_sP_logical) _tRS_rP = _tiled_smem_p.retile(rP_bf16) cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence("shared", space="cta")