SMEM-P: add debug_permute flag for coordinate permutation testing
This commit is contained in:
@@ -24,6 +24,7 @@ class FmhaKernel:
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.debug_p_one = False # DEBUG: write constant P=1.0 to verify mapping
|
||||
self.debug_swap_mn = False # DEBUG: try swapping m and n0 in coordinate mapping
|
||||
self.debug_permute = 0 # DEBUG: try different coordinate permutations
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
@@ -389,10 +390,25 @@ class FmhaKernel:
|
||||
n0 = n_local % 16
|
||||
n1 = (n_local // 16) % 4
|
||||
n2 = n_local // 64
|
||||
# Default mapping
|
||||
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
|
||||
if self.debug_swap_mn:
|
||||
pv_coord = ((n0, m_local), 0, (n1, n2), 0)
|
||||
|
||||
# DEBUG: Try different permutations to find correct mapping
|
||||
# coords = [m_local, n0, n1, n2]
|
||||
# Permutation 0: (m, n0, n1, n2) original
|
||||
# Permutation 1: (n0, m, n1, n2) swap m↔n0
|
||||
# Permutation 2: (m, n1, n0, n2) swap n0↔n1
|
||||
# Permutation 3: (m, n0, n2, n1) swap n1↔n2
|
||||
if self.debug_permute == 0:
|
||||
a,b,c,d = m_local, n0, n1, n2
|
||||
elif self.debug_permute == 1:
|
||||
a,b,c,d = n0, m_local, n1, n2
|
||||
elif self.debug_permute == 2:
|
||||
a,b,c,d = m_local, n1, n0, n2
|
||||
elif self.debug_permute == 3:
|
||||
a,b,c,d = m_local, n0, n2, n1
|
||||
else:
|
||||
a,b,c,d = m_local, n0, n1, n2
|
||||
|
||||
pv_coord = ((a, b), 0, (c, d), 0)
|
||||
|
||||
# Write normalized P value
|
||||
p_val_bf16 = p_val.to(self.q_dtype)
|
||||
|
||||
Reference in New Issue
Block a user