diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 7f9d8f01..57e95b82 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -9,6 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +from cutlass.utils.blackwell_helpers import get_smem_store_op import cuda.bindings.driver as cuda import cutlass.torch as ct import math @@ -265,10 +266,30 @@ class FmhaKernel: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # P SMEM copy atoms: SMEM-P (TBD) - # make_tiled_copy_C gives rank mismatch (QK C-fragment has 4 modes, - # PV A-operand SMEM has 3 modes). Need proper layout-aware copy. - # For now, SMEM-P path zero-fills sP. TMEM-P (hd<=64) works correctly. + # P SMEM copy atoms: SMEM-P + # Uses get_smem_store_op + make_tiled_copy_D from CUTLASS blackwell_helpers. + # This creates a SMEM store copy using the same thread partition as the TMEM load, + # so the same threads that compute P (softmax warps) can write P to sP directly. + # The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP). + if self.use_smem_p: + _p_smem_store_atom = get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load + ) + _tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load) + _thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx) + # Partition sP for the SMEM store (destination) + _sP_2d = cute.group_modes(sP, 0, 3) + _tSMEM_STOREsP = _thr_smem_store_p.partition_D(_sP_2d) + # Create a source register tensor for the SMEM store + # The source layout comes from the TMEM load's source partition (tTMEM_LOADtS) + # But we need a BF16 view of P, not the FP32 S view. + # The softmax writes P values (BF16) via the register bridge rP_bf16, + # which shares layout with rS (tTMEM_LOADrS layout). + # make_tiled_copy_D uses the same thread partition, so the source partition + # should match the TMEM load source partition. + _tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS) + # But rS is FP32. We need BF16. Create BF16 view of rP with same layout. + _rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype) row_max = -Float32.inf row_sum = Float32(0.0) @@ -341,12 +362,13 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: zero-fill sP stub (proper layout-aware copy TBD) - # make_tiled_copy_C gives rank mismatch (4 vs 3). - # Need proper P register→SMEM copy that respects QK C-fragment layout - # and PV A-operand SMEM layout. For now, TMEM-P (hd<=64) works. - for j in cutlass.range(cute.size(sP), vectorize=True): - sP[j] = self.q_dtype(0) + # SMEM-P: store P to SMEM via make_tiled_copy_D + # The P values are in rP_bf16 (BF16 view of the FP32 register bridge). + # Copy the BF16 P values to the SMEM store source registers. + for j in cutlass.range(cute.size(rP_bf16), vectorize=True): + _rP_smem[j] = rP_bf16[j] + # Write P to sP (PV A-operand SMEM layout) + cute.copy(_tiled_smem_store_p, _rP_smem, _tSMEM_STOREsP) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: tTMrO = cute.make_rmem_tensor(