From 8d48d6d543b72741b2d63eee384f67c3cd906ad6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:06:23 +0000 Subject: [PATCH] SMEM-P: try using PV A-operand layout directly for TMEM-P --- dsv4/kernels/attention/fmha.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 87783e49..8f8711c8 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -253,15 +253,17 @@ class FmhaKernel: # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) - tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) + # Try using PV A-operand layout directly (p_tmem_s.outer) instead of composed layout + # tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + # Use PV A-operand TMEM layout + tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), p_tmem_s.outer) tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) thr_store = tiled_tmem_store.get_slice(sfw_idx) tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) + # tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + # Use same PV A-operand layout for coordinate tensor + tScP = cute.make_tensor(tScS.iterator, p_tmem_s.outer) tTMEM_STOREcP = thr_store.partition_S(tScP) # Manual SMEM addressing for P (CUTLASS LLM guidance)