diff --git a/tests/unit/test_fmha_6warp_tma.cu b/tests/unit/test_fmha_6warp_tma.cu index c68e5048..2f423a75 100644 --- a/tests/unit/test_fmha_6warp_tma.cu +++ b/tests/unit/test_fmha_6warp_tma.cu @@ -25,9 +25,9 @@ static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; m constexpr int HD = HD_VAL; constexpr int SK = 128; -constexpr int MMA_K_BF16 = 16; -constexpr int TILE_SZ = 128 * MMA_K_BF16; -constexpr int V_SUB_SZ = 16 * MMA_K_BF16; +constexpr int MY_MMA_K = 16; +constexpr int TILE_SZ = 128 * MY_MMA_K; +constexpr int V_SUB_SZ = 16 * MY_MMA_K; constexpr int TMEM_N = (HD <= 128) ? 128 : 256; #include "dsv4/kernels/attention/fmha_6warp_tma.cuh"