From c7f613644f18769be8a8a118ed1ecbd7f2933582 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:11:50 +0000 Subject: [PATCH] SMEM-P: fix scoping error, disable debug_p_one, enable debug_swap_mn --- dsv4/kernels/attention/fmha.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 97f55969..4463f894 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -22,7 +22,7 @@ class FmhaKernel: self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) - self.debug_p_one = True # DEBUG: write constant P=1.0 to verify mapping + self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping self.debug_swap_mn = True # DEBUG: try swapping m and n0 in coordinate mapping self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 @@ -389,10 +389,10 @@ class FmhaKernel: n0 = n_local % 16 n1 = (n_local // 16) % 4 n2 = n_local // 64 + # Default mapping + pv_coord = ((m_local, n0), 0, (n1, n2), 0) if self.debug_swap_mn: pv_coord = ((n0, m_local), 0, (n1, n2), 0) - else: - pv_coord = ((m_local, n0), 0, (n1, n2), 0) # Write normalized P value p_val_bf16 = p_val.to(self.q_dtype)