diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index fa4b37bd..94fddf7f 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -80,14 +80,11 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, : "r"(addr)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - // Lane 0 writes first 8 output values for this warp's row - // Each lane gets values for its row (tid = wid*32 + lane) + // Each lane writes its row's 8 values (not just lane 0) int out_row = wid * 32 + lane; - if (lane == 0 && n < 1) { // Only first 8 cols for debug + if (n < 1 && out_row < 128) { // First 8 cols only for debug for (int c = 0; c < 8; c++) { - if (out_row < 128) { - s_out[out_row * 8 + c] = tmp[c] * scale; // Apply 1/sqrt(HD) scale - } + s_out[out_row * 8 + c] = tmp[c] * scale; } } } @@ -147,13 +144,22 @@ int main() { for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]); printf("\n"); + // Check ALL rows float max_diff = 0.0f, max_val = 0.0f; - for (int c = 0; c < 8; c++) { - max_diff = fmaxf(max_diff, fabsf(h_s_out[0*8+c] - h_s_scalar[c])); - max_val = fmaxf(max_val, fabsf(h_s_scalar[c])); + int n_match = 0, n_total = 0; + for (int r = 0; r < SK; r++) { + for (int c = 0; c < 8; c++) { + float mma_val = h_s_out[r * 8 + c]; + float ref_val = h_s_scalar[r]; // Scalar only has 1 val per row (all 8 cols should be same for this test) + float diff = fabsf(mma_val - ref_val); + if (diff < 0.01f * (fabsf(ref_val) + 1e-6f)) n_match++; + n_total++; + max_diff = fmaxf(max_diff, diff); + max_val = fmaxf(max_val, fabsf(ref_val)); + } } float rel_err = max_val > 0 ? max_diff / max_val : max_diff; - printf("Row 0 max rel err: %.6f\n", rel_err); + printf("Max abs diff: %.6f, Max rel err: %.6f, Match: %d/%d\n", max_diff, rel_err, n_match, n_total); // Print a few more rows for (int r : {32, 64, 96}) {