diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu index d9f14f81..31d9164f 100644 --- a/tests/unit/test_fmha_6warp_multirow.cu +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -154,6 +154,7 @@ int main() { printf("Multi-row FMHA test (HD=%d)\n", HD); int ok = 1; + // Single-head, single-batch: T=1..128 ok &= test_single_T(1); ok &= test_single_T(2); ok &= test_single_T(4); @@ -162,6 +163,14 @@ int main() { ok &= test_single_T(32); ok &= test_single_T(64); ok &= test_single_T(128); + // Multi-head prefill + ok &= test_single_T(4, 4, 1); // 4 heads, T=4 + ok &= test_single_T(16, 4, 1); // 4 heads, T=16 + ok &= test_single_T(32, 4, 1); // 4 heads, T=32 + ok &= test_single_T(64, 4, 1); // 4 heads, T=64 + // Batched + ok &= test_single_T(1, 2, 2); // 2 heads, 2 batch, T=1 + ok &= test_single_T(16, 2, 2); // 2 heads, 2 batch, T=16 printf("\n%s\n", ok ? "ALL PASSED" : "SOME FAILED"); return ok ? 0 : 1;