From b67668d2bdb7897c9d0161832bfee557dd3d376f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:43:12 +0000 Subject: [PATCH] feat: SMEM-P with make_tiled_copy_tv + manual fill --- dsv4/kernels/attention/fmha.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index e4f13846..188179e7 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -375,9 +375,8 @@ class FmhaKernel: # subtiled format: (16, 4, 2) with strides (1, 16, 8192). # Thread stride = 64 (one row in sP's layout). # - # Using make_tiled_copy_tv with thread layout and value layout: - # thr_layout: (128,) with stride (1,) — 128 threads - # val_layout: (16, 4, 2) with stride (1, 16, 8192) — 128 values + # We create the source register tensor in the copy's value order + # and fill it from rP_bf16 during the softmax loop. _smem_p_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.q_dtype, @@ -394,7 +393,24 @@ class FmhaKernel: (128, 16, 4, 2), stride=(64, 1, 16, 8192) )) _tRS_sP = _thr_smem_p.partition_D(_sP_logical) - _tRS_rP = _tiled_smem_p.retile(rP_bf16) + # Create source register tensor in copy's value order + _tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) + # Fill _tRS_rP from rP_bf16 using coordinate-based mapping + # Each thread owns row sfw_idx of the P matrix. + # rP_bf16 is in TMEM-load register order: ((32,1),4,1,1) + # _tRS_rP is in copy value order: (16, 4, 2) per thread + # For thread i, rP_bf16[(j0,0),j1,0,0] corresponds to P[i, j0+32*j1] + # And _tRS_rP[k0,k1,k2] corresponds to P[i, k0+16*k1+64*k2] + # So k0+16*k1+64*k2 = j0+32*j1, which means: + # j0 = (k0+16*k1+64*k2) % 32 + # j1 = (k0+16*k1+64*k2) // 32 + for k0 in range(16): + for k1 in range(4): + for k2 in range(2): + flat_k = k0 + 16 * k1 + 64 * k2 # k index in [0,128) + j0 = flat_k % 32 + j1 = flat_k // 32 + _tRS_rP[k0, k1, k2] = rP_bf16[(j0, 0), j1, 0, 0] cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence("shared", space="cta") if kt > 0: