diff --git a/tests/unit/test_fmha_6warp_tma.cu b/tests/unit/test_fmha_6warp_tma.cu index 7a62fc7c..3631c765 100644 --- a/tests/unit/test_fmha_6warp_tma.cu +++ b/tests/unit/test_fmha_6warp_tma.cu @@ -28,7 +28,7 @@ constexpr int SK = 128; constexpr int MY_MMA_K = 16; constexpr int TILE_SZ = 128 * MY_MMA_K; constexpr int V_SUB_SZ = 16 * MY_MMA_K; -constexpr int TMEM_N = (HD <= 128) ? 128 : 256; +constexpr int TMEM_N = (HD <= 128) ? 128 : (HD <= 256) ? 256 : 512; #include "dsv4/kernels/attention/fmha_6warp_tma.cuh"