diff --git a/tests/unit/test_p5_multitile.cu b/tests/unit/test_p5_multitile.cu index 9f475be7..f8755f0e 100644 --- a/tests/unit/test_p5_multitile.cu +++ b/tests/unit/test_p5_multitile.cu @@ -149,7 +149,80 @@ int main() { for (int d = 0; d < 5; d++) printf(" %.4f", h_o_ref[d]); printf("\n"); - cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); + // Also test: 2 separate single-tile launches + Python merge + bf16_t *d_k1, *d_v1, *d_k2, *d_v2; + bf16_t *d_o1, *d_o2; + float *d_lse1, *d_lse2; + cudaMalloc(&d_k1, 128 * HD * 2); + cudaMalloc(&d_k2, 128 * HD * 2); + cudaMalloc(&d_v1, HD * 128 * 2); + cudaMalloc(&d_v2, HD * 128 * 2); + cudaMalloc(&d_o1, HD * 2); + cudaMalloc(&d_o2, HD * 2); + cudaMalloc(&d_lse1, 4); + cudaMalloc(&d_lse2, 4); + + cudaMemcpy(d_k1, d_k, 128 * HD * 2, cudaMemcpyDeviceToDevice); + cudaMemcpy(d_k2, d_k + 128 * HD, 128 * HD * 2, cudaMemcpyDeviceToDevice); + cudaMemcpy(d_v1, d_v, HD * 128 * 2, cudaMemcpyDeviceToDevice); + cudaMemcpy(d_v2, d_v + HD * 128, HD * 128 * 2, cudaMemcpyDeviceToDevice); + + // Run single-tile on first half + FmhaParams p1 = params; + p1.k = d_k1; p1.v = d_v1; p1.o = d_o1; p1.lse = d_lse1; + p1.s_k = 128; p1.n_kv_tiles = 1; + fmha_6warp_multihead_kernel<<>>(p1); + + // Run single-tile on second half + FmhaParams p2 = params; + p2.k = d_k2; p2.v = d_v2; p2.o = d_o2; p2.lse = d_lse2; + p2.s_k = 128; p2.n_kv_tiles = 1; + fmha_6warp_multihead_kernel<<>>(p2); + + cudaDeviceSynchronize(); + + // Read single-tile results + bf16_t h_o1[HD], h_o2[HD]; + float h_lse1, h_lse2; + cudaMemcpy(h_o1, d_o1, HD * 2, cudaMemcpyDeviceToHost); + cudaMemcpy(h_o2, d_o2, HD * 2, cudaMemcpyDeviceToHost); + cudaMemcpy(&h_lse1, d_lse1, 4, cudaMemcpyDeviceToHost); + cudaMemcpy(&h_lse2, d_lse2, 4, cudaMemcpyDeviceToHost); + + // Python merge: O = (exp(lse1)*O1 + exp(lse2)*O2) / (exp(lse1) + exp(lse2)) + float e1 = expf(h_lse1), e2 = expf(h_lse2); + float h_o_merge[HD]; + for (int d = 0; d < HD; d++) { + float o1 = hbf16_to_f32(h_o1[d]) * e1; + float o2 = hbf16_to_f32(h_o2[d]) * e2; + h_o_merge[d] = (o1 + o2) / (e1 + e2); + } + + // Compare merge vs reference + float cos_merge = 0, nm_a = 0, nm_b = 0; + for (int d = 0; d < HD; d++) { + cos_merge += h_o_ref[d] * h_o_merge[d]; + nm_a += h_o_ref[d] * h_o_ref[d]; + nm_b += h_o_merge[d] * h_o_merge[d]; + } + cos_merge /= sqrtf(nm_a * nm_b + 1e-30f); + + printf(" Single-tile merge: cos=%.6f lse1=%.4f lse2=%.4f\n", cos_merge, h_lse1, h_lse2); + + // Compare in-kernel multi-tile vs Python merge + float cos_vs_merge = 0; nm_a = 0; nm_b = 0; + for (int d = 0; d < HD; d++) { + float a = h_o_merge[d]; + float b = hbf16_to_f32(h_o[d]); + cos_vs_merge += a * b; + nm_a += a * a; + nm_b += b * b; + } + cos_vs_merge /= sqrtf(nm_a * nm_b + 1e-30f); + printf(" In-kernel vs Python merge: cos=%.6f\n", cos_vs_merge); + + cudaFree(d_k1); cudaFree(d_k2); cudaFree(d_v1); cudaFree(d_v2); + cudaFree(d_o1); cudaFree(d_o2); cudaFree(d_lse1); cudaFree(d_lse2); if (cos >= 0.999990) { printf("PASS\n"); return 0; } else { printf("FAIL (cos < 0.999990)\n"); return 1; }