diff --git a/tests/unit/test_fmha_6warp_multihead.cu b/tests/unit/test_fmha_6warp_multihead.cu index 82dd0e97..a13de359 100644 --- a/tests/unit/test_fmha_6warp_multihead.cu +++ b/tests/unit/test_fmha_6warp_multihead.cu @@ -99,7 +99,6 @@ static int test_mha(int n_h) { params.s_k = SK; params.scale = SCALE; params.head_dim = HD; - params.n_kv_tiles = 0; // auto-calculate params.q_head_stride = HD; // T=1, stride = 1 * hd params.q_batch_stride = n_h * HD; params.k_head_stride = SK * HD; // each head has its own K @@ -219,7 +218,6 @@ static int test_mqa(int n_q, int n_kv) { params.s_k = SK; params.scale = SCALE; params.head_dim = HD; - params.n_kv_tiles = 0; // auto-calculate params.q_head_stride = HD; params.q_batch_stride = n_q * HD; params.k_head_stride = 0; // MQA: all heads share same K @@ -314,7 +312,6 @@ static int test_batched(int n_h, int batch_size) { params.s_k = SK; params.scale = SCALE; params.head_dim = HD; - params.n_kv_tiles = 0; // auto-calculate params.q_head_stride = HD; params.q_batch_stride = n_h * HD; params.k_head_stride = SK * HD;