diff --git a/tests/unit/test_fmha_6warp_tma_multitile.cu b/tests/unit/test_fmha_6warp_tma_multitile.cu index 2d8ea96e..e3494c1f 100644 --- a/tests/unit/test_fmha_6warp_tma_multitile.cu +++ b/tests/unit/test_fmha_6warp_tma_multitile.cu @@ -75,7 +75,7 @@ int main() { int total_fail = 0; - for (int s_k : {128}) { + for (int s_k : {128, 256, 384, 512}) { printf("\n--- s_k=%d (%d KV tiles) ---\n", s_k, (s_k + 127) / 128); bf16_t* h_q = (bf16_t*)calloc(HD, sizeof(bf16_t));