From afa4b8a746e2b4dadfe833cd3948b9138636e325 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:10:39 +0000 Subject: [PATCH] SMEM-P: add debug_swap_mn flag to test swapped coordinate mapping --- dsv4/kernels/attention/fmha.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 346c3fe6..97f55969 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -23,6 +23,7 @@ class FmhaKernel: self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.debug_p_one = True # DEBUG: write constant P=1.0 to verify mapping + self.debug_swap_mn = True # DEBUG: try swapping m and n0 in coordinate mapping self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -388,7 +389,10 @@ class FmhaKernel: n0 = n_local % 16 n1 = (n_local // 16) % 4 n2 = n_local // 64 - pv_coord = ((m_local, n0), 0, (n1, n2), 0) + if self.debug_swap_mn: + pv_coord = ((n0, m_local), 0, (n1, n2), 0) + else: + pv_coord = ((m_local, n0), 0, (n1, n2), 0) # Write normalized P value p_val_bf16 = p_val.to(self.q_dtype)