diff --git a/tests/unit/test_fmha_hd64_smem_p.cu b/tests/unit/test_fmha_hd64_smem_p.cu index 3b58b6c4..efed944d 100644 --- a/tests/unit/test_fmha_hd64_smem_p.cu +++ b/tests/unit/test_fmha_hd64_smem_p.cu @@ -220,6 +220,8 @@ int main() { // SMEM: tmem(4+12) + sQ(4*4096) + sK(4*4096) + sPk(4096) + sV(8*2048) + s_p_vals(512) + align int smem = (4+16 + NKT_QK*TILE_SZ*2 + NKT_QK*TILE_SZ*2 + TILE_SZ*2 + NKT_PV*V_TILE_SZ*2 + SK*4 + 256 + 127) & ~127; printf("SMEM: %d bytes (%.1f KB, limit 232 KB)\n", smem, smem/1024.0f); + // Must opt into >48KB shared memory on SM100 + cudaFuncSetAttribute(test_fmha_hd64_smem_p, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); test_fmha_hd64_smem_p<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE); cudaError_t launch_err = cudaGetLastError();