test: verify ALL 128 rows × 8 cols match scalar reference
This commit is contained in:
@@ -80,14 +80,11 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
|
||||
: "r"(addr));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
|
||||
// Lane 0 writes first 8 output values for this warp's row
|
||||
// Each lane gets values for its row (tid = wid*32 + lane)
|
||||
// Each lane writes its row's 8 values (not just lane 0)
|
||||
int out_row = wid * 32 + lane;
|
||||
if (lane == 0 && n < 1) { // Only first 8 cols for debug
|
||||
if (n < 1 && out_row < 128) { // First 8 cols only for debug
|
||||
for (int c = 0; c < 8; c++) {
|
||||
if (out_row < 128) {
|
||||
s_out[out_row * 8 + c] = tmp[c] * scale; // Apply 1/sqrt(HD) scale
|
||||
}
|
||||
s_out[out_row * 8 + c] = tmp[c] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -147,13 +144,22 @@ int main() {
|
||||
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]);
|
||||
printf("\n");
|
||||
|
||||
// Check ALL rows
|
||||
float max_diff = 0.0f, max_val = 0.0f;
|
||||
for (int c = 0; c < 8; c++) {
|
||||
max_diff = fmaxf(max_diff, fabsf(h_s_out[0*8+c] - h_s_scalar[c]));
|
||||
max_val = fmaxf(max_val, fabsf(h_s_scalar[c]));
|
||||
int n_match = 0, n_total = 0;
|
||||
for (int r = 0; r < SK; r++) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
float mma_val = h_s_out[r * 8 + c];
|
||||
float ref_val = h_s_scalar[r]; // Scalar only has 1 val per row (all 8 cols should be same for this test)
|
||||
float diff = fabsf(mma_val - ref_val);
|
||||
if (diff < 0.01f * (fabsf(ref_val) + 1e-6f)) n_match++;
|
||||
n_total++;
|
||||
max_diff = fmaxf(max_diff, diff);
|
||||
max_val = fmaxf(max_val, fabsf(ref_val));
|
||||
}
|
||||
}
|
||||
float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
|
||||
printf("Row 0 max rel err: %.6f\n", rel_err);
|
||||
printf("Max abs diff: %.6f, Max rel err: %.6f, Match: %d/%d\n", max_diff, rel_err, n_match, n_total);
|
||||
|
||||
// Print a few more rows
|
||||
for (int r : {32, 64, 96}) {
|
||||
|
||||
Reference in New Issue
Block a user