test: add multi-head and batched prefill tests for multirow kernel
This commit is contained in:
@@ -154,6 +154,7 @@ int main() {
|
||||
printf("Multi-row FMHA test (HD=%d)\n", HD);
|
||||
|
||||
int ok = 1;
|
||||
// Single-head, single-batch: T=1..128
|
||||
ok &= test_single_T(1);
|
||||
ok &= test_single_T(2);
|
||||
ok &= test_single_T(4);
|
||||
@@ -162,6 +163,14 @@ int main() {
|
||||
ok &= test_single_T(32);
|
||||
ok &= test_single_T(64);
|
||||
ok &= test_single_T(128);
|
||||
// Multi-head prefill
|
||||
ok &= test_single_T(4, 4, 1); // 4 heads, T=4
|
||||
ok &= test_single_T(16, 4, 1); // 4 heads, T=16
|
||||
ok &= test_single_T(32, 4, 1); // 4 heads, T=32
|
||||
ok &= test_single_T(64, 4, 1); // 4 heads, T=64
|
||||
// Batched
|
||||
ok &= test_single_T(1, 2, 2); // 2 heads, 2 batch, T=1
|
||||
ok &= test_single_T(16, 2, 2); // 2 heads, 2 batch, T=16
|
||||
|
||||
printf("\n%s\n", ok ? "ALL PASSED" : "SOME FAILED");
|
||||
return ok ? 0 : 1;
|
||||
|
||||
Reference in New Issue
Block a user