SMEM-P: add crd2idx debug attempt

This commit is contained in:
2026-05-23 20:04:28 +00:00
parent 8879f0b701
commit db9d9b09d2

View File

@@ -382,12 +382,19 @@ class FmhaKernel:
# DEBUG: Print first few coordinates to verify mapping
if self.use_smem_p and k < 2 and j < 2:
print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}")
# Try to compute the linear offset of pv_coord in sP's layout using crd2idx
try:
offset = cute.crd2idx(pv_coord, sP.layout)
print(f"[SMEM-P DEBUG] offset = {offset}")
except:
print(f"[SMEM-P DEBUG] crd2idx not available")
# DEBUG: Also write a pattern based on the fragment indices (k, j).
# If the computed coordinates are wrong, this pattern might work better.
# pattern_val = Float32(k) + Float32(j) * Float32(32.0)
# pattern_bf16 = pattern_val.to(self.q_dtype)
# sP[pv_coord] = pattern_bf16
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()