From 8879f0b701f6dabca43a1c1ea9eb40337db28bbc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:03:22 +0000 Subject: [PATCH] SMEM-P: test pattern based on fragment indices (k,j) --- dsv4/kernels/attention/fmha.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 9a58e436..95015d29 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -372,13 +372,22 @@ class FmhaKernel: n2 = n_local // 64 pv_coord = ((m_local, n0), 0, (n1, n2), 0) - # Write actual P value (not test pattern) - p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) + # DEBUG: Write pattern based on fragment indices (k,j) + # If coordinates wrong, this pattern might work better + pattern_val = Float32(k) + Float32(j) * Float32(32.0) + p_val_bf16 = pattern_val.to(self.q_dtype) + # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) sP[pv_coord] = p_val_bf16 # Tensor indexing # DEBUG: Print first few coordinates to verify mapping if self.use_smem_p and k < 2 and j < 2: print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}") + + # DEBUG: Also write pattern based on fragment indices (k,j) + # If coordinates wrong, this pattern might work better + # pattern_val = Float32(k) + Float32(j) * Float32(32.0) + # pattern_bf16 = pattern_val.to(self.q_dtype) + # sP[pv_coord] = pattern_bf16 row_sum = row_sum + tTMEM_LOADrS_frg[k, j] s_vec = tTMEM_LOADrS_frg[None, j].load()