test: add multi-head and batched prefill tests for multirow kernel

This commit is contained in:
2026-05-28 23:48:53 +00:00
parent ac8fa779e2
commit ca5cf0e517

View File

@@ -154,6 +154,7 @@ int main() {
printf("Multi-row FMHA test (HD=%d)\n", HD);
int ok = 1;
// Single-head, single-batch: T=1..128
ok &= test_single_T(1);
ok &= test_single_T(2);
ok &= test_single_T(4);
@@ -162,6 +163,14 @@ int main() {
ok &= test_single_T(32);
ok &= test_single_T(64);
ok &= test_single_T(128);
// Multi-head prefill
ok &= test_single_T(4, 4, 1); // 4 heads, T=4
ok &= test_single_T(16, 4, 1); // 4 heads, T=16
ok &= test_single_T(32, 4, 1); // 4 heads, T=32
ok &= test_single_T(64, 4, 1); // 4 heads, T=64
// Batched
ok &= test_single_T(1, 2, 2); // 2 heads, 2 batch, T=1
ok &= test_single_T(16, 2, 2); // 2 heads, 2 batch, T=16
printf("\n%s\n", ok ? "ALL PASSED" : "SOME FAILED");
return ok ? 0 : 1;