diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 346c3fe6..97f55969 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -23,6 +23,7 @@ class FmhaKernel: self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.debug_p_one = True # DEBUG: write constant P=1.0 to verify mapping + self.debug_swap_mn = True # DEBUG: try swapping m and n0 in coordinate mapping self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -388,7 +389,10 @@ class FmhaKernel: n0 = n_local % 16 n1 = (n_local // 16) % 4 n2 = n_local // 64 - pv_coord = ((m_local, n0), 0, (n1, n2), 0) + if self.debug_swap_mn: + pv_coord = ((n0, m_local), 0, (n1, n2), 0) + else: + pv_coord = ((m_local, n0), 0, (n1, n2), 0) # Write normalized P value p_val_bf16 = p_val.to(self.q_dtype)