SMEM-P: add debug prints for coordinates

This commit is contained in:
2026-05-23 20:00:33 +00:00
parent 8edf2d434c
commit d2d0eec33a

View File

@@ -372,6 +372,10 @@ class FmhaKernel:
# Write actual P value (not test pattern)
p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
# DEBUG: Print first few coordinates to verify mapping
if self.use_smem_p and k < 2 and j < 2:
print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}")
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()