test: full matrix for D1.5 multirow multitile

This commit is contained in:
2026-05-30 04:49:00 +00:00
parent 5544d3a0a4
commit f2544a4600

View File

@@ -205,18 +205,24 @@ int main() {
printf("START: test_fmha_6warp_tma_multirow_multitile HD=%d\n", HD);
int total_fail = 0;
// Just the most basic test first
total_fail += test_single(1, 128);
fflush(stdout);
printf("\n=== 6-warp TMA FMHA multi-row multi-tile HD=%d ===\n", HD);
if (total_fail == 0) {
total_fail += test_single(4, 128);
fflush(stdout);
total_fail += test_single(1, 256);
fflush(stdout);
// Single KV tile (s_k=128, baseline)
for (int T : {1, 4, 32, 128}) {
total_fail += test_single(T, 128);
}
// Multi-tile KV — the whole point of D1.5
for (int s_k : {256, 384, 512}) {
for (int T : {1, 4, 32, 128}) {
total_fail += test_single(T, s_k);
}
}
// Multi-head + batch
total_fail += test_single(4, 256, 4, 1); // n_h=4
total_fail += test_single(4, 256, 2, 2); // n_h=2, batch=2
printf("\nOverall: %s\n", total_fail == 0 ? "ALL PASSED" : "SOME FAILED");
fflush(stdout);
return total_fail == 0 ? 0 : 1;
}