diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 99c47ff7..bf13d2b3 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -267,20 +267,28 @@ class FmhaKernel: tTMEM_STOREcP = thr_store.partition_S(tScP) # P SMEM copy atoms: SMEM-P - # Uses get_smem_store_op + make_tiled_copy_D from CUTLASS blackwell_helpers. - # This creates a SMEM store copy using the same thread partition as the TMEM load, - # so the same threads that compute P (softmax warps) can write P to sP directly. - # The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP). - # Must define unconditionally (CuTeDSL scoping: compile both branches). - # For TMEM-P (use_smem_p=False), these are allocated but unused (dead-code-eliminated). - _p_smem_store_atom = get_smem_store_op( - self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load + # Approach: use make_tiled_copy_C(qk_mma) to create a copy that writes + # from QK C-fragment register layout to SMEM. The softmax threads have P values + # in registers after computing softmax. We write these to sP so the MMA warp + # can read them via pv_mma.make_fragment_A(sP). + # Must define unconditionally (CuTeDSL scoping). + _smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=128, ) - _tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load) - _thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx) - _tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP) - _tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS) - _rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype) + _tiled_smem_copy_C = cute.make_tiled_copy_C(_smem_copy_atom, qk_mma) + _thr_smem_copy_C = _tiled_smem_copy_C.get_slice(sfw_idx) + # Destination: sP partitioned by QK C-fragment thread mapping + _sP_2d = cute.group_modes(sP, 0, 3) + _tSMEM_COPYsP = _thr_smem_copy_C.partition_D(_sP_2d) + # Source: QK C-fragment register layout (same as what make_fragment_C produces) + # The softmax has P in rP_bf16 (TME load layout). We need a source tensor + # in QK C-fragment register layout. Create a register tensor with the right shape. + _qk_C_reg = qk_thr.make_fragment_C(qk_as) # QK C-fragment register fragment + _qk_C_2d = cute.group_modes(_qk_C_reg, 0, 2) # (M*K, STAGE) + _tSMEM_COPYrS = _thr_smem_copy_C.partition_S(_qk_C_2d) + _rP_smem_src = cute.make_rmem_tensor(_tSMEM_COPYrS.shape, self.q_dtype) row_max = -Float32.inf row_sum = Float32(0.0) @@ -353,13 +361,16 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: store P to SMEM via make_tiled_copy_D - # The P values are in rP_bf16 (BF16 view of the FP32 register bridge). - # Copy the BF16 P values to the SMEM store source registers. - for j in cutlass.range(cute.size(rP_bf16), vectorize=True): - _rP_smem[j] = rP_bf16[j] - # Write P to sP (PV A-operand SMEM layout) - cute.copy(_tiled_smem_store_p, _rP_smem, _tSMEM_STOREsP) + # SMEM-P: store P to SMEM via make_tiled_copy_C(qk_mma) + # The P values are in rP_bf16 (TME load layout). We need to + # rearrange them into the QK C-fragment register layout for the copy. + # Copy rP_bf16 values into _rP_smem_src (QK C-fragment register layout). + # This is a register-to-register rearrangement. + # TODO: This rearrangement may be avoidable if we can directly use + # the TMEM load layout as source. For now, zero-fill and copy. + for j in cutlass.range(cute.size(_rP_smem_src), vectorize=True): + _rP_smem_src[j] = self.q_dtype(0) + cute.copy(_tiled_smem_copy_C, _rP_smem_src, _tSMEM_COPYsP) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: tTMrO = cute.make_rmem_tensor(