Debug: override P with uniform 1/128
This commit is contained in:
@@ -95,6 +95,8 @@ test_fmha_v5(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
row_sum = wsum(row_sum);
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
|
||||
if (lane == 0) { printf("P[0..7]: "); for(int j=0;j<8;j++) printf("%.4f ", s_vals[j]); printf("\n"); }
|
||||
// DEBUG: Override P with known values
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] = 1.0f / SK; // Uniform P
|
||||
// Write P values to shared memory
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user