From 53bc54ed17a29f1ca0d1d5440ab2ee9fef3d55a7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:19:32 +0000 Subject: [PATCH] D1.5: Fix SMEM-P - use coordinate-indexed store (same proven pattern) --- dsv4/kernels/attention/fmha.py | 46 ++++++++++------------------------ 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index b8dcf549..39405779 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,39 +366,19 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using a TiledCopy derived from PV MMA's - # A-operand layout, but with softmax thread mapping. - # - # Strategy: Use the PV MMA's A-operand SMEM layout (which matches sP) - # and create a new TiledCopy with softmax thread/value layout. - # The softmax threads (128 total) each own one row of P. - # Within each row, values are in sP's subtiled format. - # - # We use pv_mma's make_tiled_copy_A to get the copy atom and tiling, - # then override the thread layout for the softmax threads. - _smem_p_tiled_copy = utils.sm100.make_tiled_copy_A( - cute.nvgpu.CopyUniversalOp(), - pv_mma, self.q_dtype, - 128, # tiler_mn - matches the P matrix tile size - ) - # Get the softmax thread's partition - _thr_smem_p = _smem_p_tiled_copy.get_slice(sfw_idx) - # Create a logical (non-swizzled) view of sP for partitioning - _sP_logical = cute.make_tensor(_sP_nostage.iterator, _sP_nostage.layout) - _tRS_sP = _thr_smem_p.partition_D(_sP_logical) - # Create source register tensor matching the copy's value order - _tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) - # Fill _tRS_rP from rP_bf16. - # rP_bf16 is in TMEM-load order: ((32,1),4,1,1) with 128 values - # _tRS_rP is in copy value order. We need to map between them. - # For the copy, each thread should own P[thread_row, :]. - # rP_bf16[(j0,0),j1,0,0] = P[thread_row, j0+32*j1] - # We need to figure out the copy's value order for our thread. - # PRINT THE SHAPES to understand the mapping - # For now, fill with zeros as a baseline test - for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True): - _tRS_rP[v_idx] = BFloat16(0.0) - cute.copy(_smem_p_tiled_copy, _tRS_rP, _tRS_sP) + # SMEM-P: write P to sP using coordinate-indexed store. + # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates + # and maps them to sP's swizzled layout. + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + cute.arch.fence_proxy("async.shared", space="cta") cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: for i in range(n_corr_tiles):