diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index f09ec31f..ffc85a37 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -367,20 +367,18 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() else: # SMEM-P: write P to sP using TiledCopy derived from QK MMA. - # - # make_tiled_copy_C with qk_mma gives threads partitioned by the - # QK C-fragment (same as TMEM load). Source: rP_bf16 (registers). - # Destination: sP (PV A-operand SMEM layout). - _smem_p_store_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=16, - ) - _tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma) - _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) - _tRS_rP = _thr_smem_p.partition_S(rP_bf16) - _tRS_sP = _thr_smem_p.partition_D(_sP_nostage) - cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) + # SMEM-P: write P to sP using coordinate-indexed store. + # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates + # and maps them to sP's swizzled layout. + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: for i in range(n_corr_tiles):