diff --git a/tests/unit/test_qk_softmax.cu b/tests/unit/test_qk_softmax.cu index aacd2605..ffc400b3 100644 --- a/tests/unit/test_qk_softmax.cu +++ b/tests/unit/test_qk_softmax.cu @@ -193,7 +193,7 @@ int main() { cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); - size_t smem = 4 + 128 + TILE_SZ + TILE_SZ + TILE_SZ + 16 + 8 + 128*4*2 + 256; + size_t smem = 4 + 128 + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + 16 + 8 + 128*sizeof(float)*2 + 256; cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem); test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE); cudaError_t err = cudaDeviceSynchronize();