SMEM-P: add debug_swap_mn flag to test swapped coordinate mapping

This commit is contained in:
2026-05-23 20:10:39 +00:00
parent aac08b57c4
commit afa4b8a746

View File

@@ -23,6 +23,7 @@ class FmhaKernel:
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.debug_p_one = True # DEBUG: write constant P=1.0 to verify mapping
self.debug_swap_mn = True # DEBUG: try swapping m and n0 in coordinate mapping
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
@@ -388,7 +389,10 @@ class FmhaKernel:
n0 = n_local % 16
n1 = (n_local // 16) % 4
n2 = n_local // 64
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
if self.debug_swap_mn:
pv_coord = ((n0, m_local), 0, (n1, n2), 0)
else:
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
# Write normalized P value
p_val_bf16 = p_val.to(self.q_dtype)