diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 64430b14..9f0e1059 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -24,7 +24,7 @@ class FmhaKernel: self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping - self.debug_permute = 0 # DEBUG: try different coordinate permutations (0=original,1=swap m↔n0) + self.debug_permute = 4 # DEBUG: try different coordinate permutations (4=swap m↔n2) self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -405,6 +405,10 @@ class FmhaKernel: # Permutation 1: (n0, m, n1, n2) swap m↔n0 # Permutation 2: (m, n1, n0, n2) swap n0↔n1 # Permutation 3: (m, n0, n2, n1) swap n1↔n2 + # Permutation 4: (n2, n0, n1, m) swap m↔n2 + # Permutation 5: (n1, n0, m, n2) swap m↔n1 + # Permutation 6: (n0, n1, n2, m) rotate left + # Permutation 7: (n2, n1, n0, m) reverse if self.debug_permute == 0: a,b,c,d = m_local, n0, n1, n2 elif self.debug_permute == 1: @@ -413,6 +417,14 @@ class FmhaKernel: a,b,c,d = m_local, n1, n0, n2 elif self.debug_permute == 3: a,b,c,d = m_local, n0, n2, n1 + elif self.debug_permute == 4: + a,b,c,d = n2, n0, n1, m_local + elif self.debug_permute == 5: + a,b,c,d = n1, n0, m_local, n2 + elif self.debug_permute == 6: + a,b,c,d = n0, n1, n2, m_local + elif self.debug_permute == 7: + a,b,c,d = n2, n1, n0, m_local else: a,b,c,d = m_local, n0, n1, n2