Fix reference to also use uniform P

This commit is contained in:
2026-05-28 14:47:10 +00:00
parent 75bdcbf728
commit 9a3b43c42b

View File

@@ -152,7 +152,7 @@ test_fmha_v5(const bf16_t* q, const bf16_t* k, const bf16_t* v,
for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
float sm = 0.0f;
for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
for (int j=0;j<SK;j++) s[j] /= sm;
for (int j=0;j<SK;j++) s[j] = 1.0f / SK; // Uniform P (matching override)
for (int d=0;d<HD;d++) {
float ov = 0.0f;
for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);