Debug: override P with uniform 1/128

This commit is contained in:
2026-05-28 14:46:21 +00:00
parent af93c283c7
commit 75bdcbf728

View File

@@ -95,6 +95,8 @@ test_fmha_v5(const bf16_t* q, const bf16_t* k, const bf16_t* v,
row_sum = wsum(row_sum);
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
if (lane == 0) { printf("P[0..7]: "); for(int j=0;j<8;j++) printf("%.4f ", s_vals[j]); printf("\n"); }
// DEBUG: discard the normalized P computed above and force a uniform distribution (every entry = 1/SK)
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] = 1.0f / SK; // Uniform P
// Write P values to shared memory
if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
}