test: full matrix for D1.5 multirow multitile
This commit is contained in:
@@ -205,18 +205,24 @@ int main() {
|
||||
printf("START: test_fmha_6warp_tma_multirow_multitile HD=%d\n", HD);
|
||||
int total_fail = 0;
|
||||
|
||||
// Just the most basic test first
|
||||
total_fail += test_single(1, 128);
|
||||
fflush(stdout);
|
||||
printf("\n=== 6-warp TMA FMHA multi-row multi-tile HD=%d ===\n", HD);
|
||||
|
||||
if (total_fail == 0) {
|
||||
total_fail += test_single(4, 128);
|
||||
fflush(stdout);
|
||||
total_fail += test_single(1, 256);
|
||||
fflush(stdout);
|
||||
// Single KV tile (s_k=128, baseline)
|
||||
for (int T : {1, 4, 32, 128}) {
|
||||
total_fail += test_single(T, 128);
|
||||
}
|
||||
|
||||
// Multi-tile KV — the whole point of D1.5
|
||||
for (int s_k : {256, 384, 512}) {
|
||||
for (int T : {1, 4, 32, 128}) {
|
||||
total_fail += test_single(T, s_k);
|
||||
}
|
||||
}
|
||||
|
||||
// Multi-head + batch
|
||||
total_fail += test_single(4, 256, 4, 1); // n_h=4
|
||||
total_fail += test_single(4, 256, 2, 2); // n_h=2, batch=2
|
||||
|
||||
printf("\nOverall: %s\n", total_fail == 0 ? "ALL PASSED" : "SOME FAILED");
|
||||
fflush(stdout);
|
||||
return total_fail == 0 ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user