SMEM-P: test permutation 4 (swap m↔n2)

This commit is contained in:
2026-05-23 20:20:07 +00:00
parent c7a299d7d9
commit 993ec32567

View File

@@ -24,7 +24,7 @@ class FmhaKernel:
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping
self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping
self.debug_permute = 0 # DEBUG: try different coordinate permutations (0=original,1=swap m↔n0)
self.debug_permute = 4 # DEBUG: try different coordinate permutations (4=swap m↔n2)
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
@@ -405,6 +405,10 @@ class FmhaKernel:
# Permutation 1: (n0, m, n1, n2) swap m↔n0
# Permutation 2: (m, n1, n0, n2) swap n0↔n1
# Permutation 3: (m, n0, n2, n1) swap n1↔n2
# Permutation 4: (n2, n0, n1, m) swap m↔n2
# Permutation 5: (n1, n0, m, n2) swap m↔n1
# Permutation 6: (n0, n1, n2, m) rotate right
# Permutation 7: (n2, n1, n0, m) reverse
if self.debug_permute == 0:
a,b,c,d = m_local, n0, n1, n2
elif self.debug_permute == 1:
@@ -413,6 +417,14 @@ class FmhaKernel:
a,b,c,d = m_local, n1, n0, n2
elif self.debug_permute == 3:
a,b,c,d = m_local, n0, n2, n1
elif self.debug_permute == 4:
a,b,c,d = n2, n0, n1, m_local
elif self.debug_permute == 5:
a,b,c,d = n1, n0, m_local, n2
elif self.debug_permute == 6:
a,b,c,d = n0, n1, n2, m_local
elif self.debug_permute == 7:
a,b,c,d = n2, n1, n0, m_local
else:
a,b,c,d = m_local, n0, n1, n2