test: verify ALL 128 rows × 8 cols match scalar reference

This commit is contained in:
2026-05-28 11:36:46 +00:00
parent 3c7d9d9303
commit 6f40fafa91

View File

@@ -80,14 +80,11 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
: "r"(addr));
asm volatile("tcgen05.wait::ld.sync.aligned;");
// Lane 0 writes first 8 output values for this warp's row
// Each lane gets values for its row (tid = wid*32 + lane)
// Each lane writes its row's 8 values (not just lane 0)
int out_row = wid * 32 + lane;
if (lane == 0 && n < 1) { // Only first 8 cols for debug
if (n < 1 && out_row < 128) { // First 8 cols only for debug
for (int c = 0; c < 8; c++) {
if (out_row < 128) {
s_out[out_row * 8 + c] = tmp[c] * scale; // Apply 1/sqrt(HD) scale
}
s_out[out_row * 8 + c] = tmp[c] * scale;
}
}
}
@@ -147,13 +144,22 @@ int main() {
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]);
printf("\n");
// Check ALL rows
float max_diff = 0.0f, max_val = 0.0f;
for (int c = 0; c < 8; c++) {
max_diff = fmaxf(max_diff, fabsf(h_s_out[0*8+c] - h_s_scalar[c]));
max_val = fmaxf(max_val, fabsf(h_s_scalar[c]));
int n_match = 0, n_total = 0;
for (int r = 0; r < SK; r++) {
for (int c = 0; c < 8; c++) {
float mma_val = h_s_out[r * 8 + c];
float ref_val = h_s_scalar[r]; // Scalar only has 1 val per row (all 8 cols should be same for this test)
float diff = fabsf(mma_val - ref_val);
if (diff < 0.01f * (fabsf(ref_val) + 1e-6f)) n_match++;
n_total++;
max_diff = fmaxf(max_diff, diff);
max_val = fmaxf(max_val, fabsf(ref_val));
}
}
float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
printf("Row 0 max rel err: %.6f\n", rel_err);
printf("Max abs diff: %.6f, Max rel err: %.6f, Match: %d/%d\n", max_diff, rel_err, n_match, n_total);
// Print a few more rows
for (int r : {32, 64, 96}) {