Fix reference to also use uniform P
This commit is contained in:
@@ -152,7 +152,7 @@ test_fmha_v5(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
|
||||
float sm = 0.0f;
|
||||
for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
|
||||
for (int j=0;j<SK;j++) s[j] /= sm;
|
||||
for (int j=0;j<SK;j++) s[j] = 1.0f / SK; // Uniform P (matching override)
|
||||
for (int d=0;d<HD;d++) {
|
||||
float ov = 0.0f;
|
||||
for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);
|
||||
|
||||
Reference in New Issue
Block a user