diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 154ccd1f..f09ec31f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,19 +366,21 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using coordinate-indexed store. - # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates - # and maps them to sP's swizzled layout. - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] - cute.arch.fence_proxy("async.shared", space="cta") + # SMEM-P: write P to sP using TiledCopy derived from QK MMA. + # + # make_tiled_copy_C with qk_mma gives threads partitioned by the + # QK C-fragment (same as TMEM load). Source: rP_bf16 (registers). + # Destination: sP (PV A-operand SMEM layout). + _smem_p_store_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=16, + ) + _tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma) + _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) + _tRS_rP = _thr_smem_p.partition_S(rP_bf16) + _tRS_sP = _thr_smem_p.partition_D(_sP_nostage) + cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: for i in range(n_corr_tiles):