diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu index 8554717a..2abcf2f6 100644 --- a/tests/unit/test_fmha_6warp_multirow.cu +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -197,6 +197,8 @@ int main() { cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice); + constexpr int TILE_SZ = 128 * MMA_K_BF16; + constexpr int V_SUB_SZ = 16 * MMA_K_BF16; size_t smem_off = 8 + 128*4 + 128*4; smem_off = ((smem_off + 127) & ~(size_t)127); smem_off += TILE_SZ + TILE_SZ;