diff --git a/tests/unit/test_fmha_6warp_tma_multirow_multitile.cu b/tests/unit/test_fmha_6warp_tma_multirow_multitile.cu index 20b6a6ec..3183db8f 100644 --- a/tests/unit/test_fmha_6warp_tma_multirow_multitile.cu +++ b/tests/unit/test_fmha_6warp_tma_multirow_multitile.cu @@ -205,18 +205,24 @@ int main() { printf("START: test_fmha_6warp_tma_multirow_multitile HD=%d\n", HD); int total_fail = 0; - // Just the most basic test first - total_fail += test_single(1, 128); - fflush(stdout); + printf("\n=== 6-warp TMA FMHA multi-row multi-tile HD=%d ===\n", HD); - if (total_fail == 0) { - total_fail += test_single(4, 128); - fflush(stdout); - total_fail += test_single(1, 256); - fflush(stdout); + // Single KV tile (s_k=128, baseline) + for (int T : {1, 4, 32, 128}) { + total_fail += test_single(T, 128); } + // Multi-tile KV — the whole point of D1.5 + for (int s_k : {256, 384, 512}) { + for (int T : {1, 4, 32, 128}) { + total_fail += test_single(T, s_k); + } + } + + // Multi-head + batch + total_fail += test_single(4, 256, 4, 1); // n_h=4 + total_fail += test_single(4, 256, 2, 2); // n_h=2, batch=2 + printf("\nOverall: %s\n", total_fail == 0 ? "ALL PASSED" : "SOME FAILED"); - fflush(stdout); return total_fail == 0 ? 0 : 1; }