From 2cc9786491e14c282a98ef2a4acca41111a7d8b1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 05:06:45 +0000 Subject: [PATCH] Fix p_tmem_s: use ComposedLayout from make_smem_layout_a, pass as kernel arg --- dsv4/kernels/attention/fmha.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 5f737f53..456ae698 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -90,8 +90,8 @@ class FmhaKernel: if self.num_tmem_alloc_cols > 512: print(f"⚠️ TMEM BUDGET: {self.num_tmem_alloc_cols} cols (hd={hd})") - # P TMEM alias (PV A-operand viewed as TMEM for partition mapping) - self.p_tmem_s = tStS # reuses QK C-fragment TMEM layout for P partition + # P TMEM layout (PV A-operand SMEM layout — used to alias QK C-fragment in TMEM) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) # TMA bytes cta = cute.size(qk_mma.thr_id.shape) @@ -143,7 +143,7 @@ class FmhaKernel: self._kernel( qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, self.cluster_layout_vmnk, - self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_smem_s, self.c_smem_s, + self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile, ).launch( grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream, @@ -152,7 +152,7 @@ class FmhaKernel: @cute.kernel def _kernel( self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, - cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, c_smem_s, epi_tile, + cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, p_tmem_s, c_smem_s, epi_tile, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx, _, _ = cute.arch.thread_idx() @@ -253,7 +253,7 @@ class FmhaKernel: # ── TMEM-P path: PV A-operand from TMEM ── if not use_smem_p: - tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) tOrP_base = pv_thr.make_fragment_A(tP) tOrP = tOrP_base[(None, None, None, 0)] tOrP0 = cute.make_tensor(