From 70c9d93d28919e5b79f6a292d455d5f4bee929e5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:23:48 +0000 Subject: [PATCH] feat: SMEM-P make_tiled_copy_C + zero-fill dest tensor --- dsv4/kernels/attention/fmha.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index f09ec31f..011ae945 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,11 +366,14 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using TiledCopy derived from QK MMA. + # SMEM-P: write P to sP using make_tiled_copy_C(qk_mma) with + # manual source→destination value mapping. # - # make_tiled_copy_C with qk_mma gives threads partitioned by the - # QK C-fragment (same as TMEM load). Source: rP_bf16 (registers). - # Destination: sP (PV A-operand SMEM layout). + # make_tiled_copy_C gives the right thread partition (QK C-fragment) + # but the source (rP_bf16) and destination (sP) have different ranks. + # Solution: use partition_D to get the sP partition, create a register + # tensor matching its shape, fill from rP_bf16 via coordinate mapping, + # then copy. _smem_p_store_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.q_dtype, @@ -378,8 +381,21 @@ class FmhaKernel: ) _tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma) _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) - _tRS_rP = _thr_smem_p.partition_S(rP_bf16) _tRS_sP = _thr_smem_p.partition_D(_sP_nostage) + # Create source register tensor matching destination shape + _tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) + # Fill _tRS_rP from rP_bf16 using coordinate mapping. + # The copy's value layout indexes into the 128×128 P matrix. 
+ # Planned: use the TMEM-load coordinate tensor to map each value index + # to the corresponding (m, k) and then find the rP_bf16 element. + # Since both _tRS_rP and rP_bf16 represent the SAME P values for + # this thread, just in different layouts, the coordinate + # tensor can establish the mapping. + # + # TODO: placeholder zero-fill to test compilation and synchronization; + # sP gets zeros, not P, until the mapping above is implemented. + for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True): + _tRS_rP[v_idx] = BFloat16(0.0) cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: