Debug: print P values

This commit is contained in:
2026-05-28 14:44:09 +00:00
parent 3d15f5bb21
commit 6f5be8a4e4

View File

@@ -94,6 +94,7 @@ test_fmha_v5(const bf16_t* q, const bf16_t* k, const bf16_t* v,
}
row_sum = wsum(row_sum);
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
if (lane == 0) { printf("P[0..7]: "); for(int j=0;j<8;j++) printf("%.4f ", s_vals[j]); printf("\n"); }
// Write P values to shared memory
if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
}