From fd54d657b2fbf021184f0fd2a84b3e299562a31a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:13:44 +0000 Subject: [PATCH] SMEM-P: add debug_permute flag for coordinate permutation testing --- dsv4/kernels/attention/fmha.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index fd19f43c..a2d0a43f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -24,6 +24,7 @@ class FmhaKernel: self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping + self.debug_permute = 0 # DEBUG: try different coordinate permutations self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -389,10 +390,25 @@ class FmhaKernel: n0 = n_local % 16 n1 = (n_local // 16) % 4 n2 = n_local // 64 - # Default mapping - pv_coord = ((m_local, n0), 0, (n1, n2), 0) - if self.debug_swap_mn: - pv_coord = ((n0, m_local), 0, (n1, n2), 0) + + # DEBUG: Try different permutations to find correct mapping + # coords = [m_local, n0, n1, n2] + # Permutation 0: (m, n0, n1, n2) original + # Permutation 1: (n0, m, n1, n2) swap m↔n0 + # Permutation 2: (m, n1, n0, n2) swap n0↔n1 + # Permutation 3: (m, n0, n2, n1) swap n1↔n2 + if self.debug_permute == 0: + a,b,c,d = m_local, n0, n1, n2 + elif self.debug_permute == 1: + a,b,c,d = n0, m_local, n1, n2 + elif self.debug_permute == 2: + a,b,c,d = m_local, n1, n0, n2 + elif self.debug_permute == 3: + a,b,c,d = m_local, n0, n2, n1 + else: + a,b,c,d = m_local, n0, n1, n2 + + pv_coord = ((a, b), 0, (c, d), 0) # Write normalized P value p_val_bf16 = p_val.to(self.q_dtype)