diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 89ece768..a9008213 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -341,26 +341,22 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # ── TMEM-P: P store setup (register bridge) ── - if not use_smem_p: - p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width - tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout) - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype, - ) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) - thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) - tTMEM_STOREcP = thr_store.partition_S(tScP) + # ── P store setup (always define both paths — CuTeDSL scoping) ── + # TMEM-P: register bridge for P → TMEM + p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype, + ) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtP = thr_store.partition_D(tStP0) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tTMEM_STOREcP = thr_store.partition_S(tScP) - # ── SMEM-P: P → SMEM copy setup (TODO: proper QK→PV partition remap) ── - if use_smem_p: - # TODO: make_tiled_copy_C(store_atom, qk_mma) to partition threads by QK's C-fragment - # For now, zero sP as a stub — PV will read garbage/zero - pass + # SMEM-P: TODO — make_tiled_copy_C(store_atom, qk_mma) for QK→PV partition remap # ── O rescale / normalization setup (correction_rescale pattern from Stage C) ── corr_tile_size = 16