P5: add single-tile merge comparison to multitile test

This commit is contained in:
2026-05-30 09:06:57 +00:00
parent d424ccbcc1
commit 032cb4c7b2

View File

@@ -149,7 +149,80 @@ int main() {
for (int d = 0; d < 5; d++) printf(" %.4f", h_o_ref[d]);
printf("\n");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse);
// Also test: 2 separate single-tile launches + Python merge
bf16_t *d_k1, *d_v1, *d_k2, *d_v2;
bf16_t *d_o1, *d_o2;
float *d_lse1, *d_lse2;
cudaMalloc(&d_k1, 128 * HD * 2);
cudaMalloc(&d_k2, 128 * HD * 2);
cudaMalloc(&d_v1, HD * 128 * 2);
cudaMalloc(&d_v2, HD * 128 * 2);
cudaMalloc(&d_o1, HD * 2);
cudaMalloc(&d_o2, HD * 2);
cudaMalloc(&d_lse1, 4);
cudaMalloc(&d_lse2, 4);
cudaMemcpy(d_k1, d_k, 128 * HD * 2, cudaMemcpyDeviceToDevice);
cudaMemcpy(d_k2, d_k + 128 * HD, 128 * HD * 2, cudaMemcpyDeviceToDevice);
cudaMemcpy(d_v1, d_v, HD * 128 * 2, cudaMemcpyDeviceToDevice);
cudaMemcpy(d_v2, d_v + HD * 128, HD * 128 * 2, cudaMemcpyDeviceToDevice);
// Run single-tile on first half
FmhaParams p1 = params;
p1.k = d_k1; p1.v = d_v1; p1.o = d_o1; p1.lse = d_lse1;
p1.s_k = 128; p1.n_kv_tiles = 1;
fmha_6warp_multihead_kernel<HD, 128><<<grid, block, smem>>>(p1);
// Run single-tile on second half
FmhaParams p2 = params;
p2.k = d_k2; p2.v = d_v2; p2.o = d_o2; p2.lse = d_lse2;
p2.s_k = 128; p2.n_kv_tiles = 1;
fmha_6warp_multihead_kernel<HD, 128><<<grid, block, smem>>>(p2);
cudaDeviceSynchronize();
// Read single-tile results
bf16_t h_o1[HD], h_o2[HD];
float h_lse1, h_lse2;
cudaMemcpy(h_o1, d_o1, HD * 2, cudaMemcpyDeviceToHost);
cudaMemcpy(h_o2, d_o2, HD * 2, cudaMemcpyDeviceToHost);
cudaMemcpy(&h_lse1, d_lse1, 4, cudaMemcpyDeviceToHost);
cudaMemcpy(&h_lse2, d_lse2, 4, cudaMemcpyDeviceToHost);
// Python merge: O = (exp(lse1)*O1 + exp(lse2)*O2) / (exp(lse1) + exp(lse2))
float e1 = expf(h_lse1), e2 = expf(h_lse2);
float h_o_merge[HD];
for (int d = 0; d < HD; d++) {
float o1 = hbf16_to_f32(h_o1[d]) * e1;
float o2 = hbf16_to_f32(h_o2[d]) * e2;
h_o_merge[d] = (o1 + o2) / (e1 + e2);
}
// Compare merge vs reference
float cos_merge = 0, nm_a = 0, nm_b = 0;
for (int d = 0; d < HD; d++) {
cos_merge += h_o_ref[d] * h_o_merge[d];
nm_a += h_o_ref[d] * h_o_ref[d];
nm_b += h_o_merge[d] * h_o_merge[d];
}
cos_merge /= sqrtf(nm_a * nm_b + 1e-30f);
printf(" Single-tile merge: cos=%.6f lse1=%.4f lse2=%.4f\n", cos_merge, h_lse1, h_lse2);
// Compare in-kernel multi-tile vs Python merge
float cos_vs_merge = 0; nm_a = 0; nm_b = 0;
for (int d = 0; d < HD; d++) {
float a = h_o_merge[d];
float b = hbf16_to_f32(h_o[d]);
cos_vs_merge += a * b;
nm_a += a * a;
nm_b += b * b;
}
cos_vs_merge /= sqrtf(nm_a * nm_b + 1e-30f);
printf(" In-kernel vs Python merge: cos=%.6f\n", cos_vs_merge);
cudaFree(d_k1); cudaFree(d_k2); cudaFree(d_v1); cudaFree(d_v2);
cudaFree(d_o1); cudaFree(d_o2); cudaFree(d_lse1); cudaFree(d_lse2);
if (cos >= 0.999990) { printf("PASS\n"); return 0; }
else { printf("FAIL (cos < 0.999990)\n"); return 1; }