From 7a8ba8eeb653b253f7617c51cd33a16ad9323f67 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 19:30:50 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20SMEM=20size=20calculation=20=E2=80=94=20?= =?UTF-8?q?TILE=5FSZ=20is=20in=20BF16=20elements,=20need=20*sizeof(bf16=5F?= =?UTF-8?q?t)=20for=20bytes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_qk_softmax.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_qk_softmax.cu b/tests/unit/test_qk_softmax.cu index aacd2605..ffc400b3 100644 --- a/tests/unit/test_qk_softmax.cu +++ b/tests/unit/test_qk_softmax.cu @@ -193,7 +193,7 @@ int main() { cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); - size_t smem = 4 + 128 + TILE_SZ + TILE_SZ + TILE_SZ + 16 + 8 + 128*4*2 + 256; + size_t smem = 4 + 128 + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + 16 + 8 + 128*sizeof(float)*2 + 256; cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem); test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE); cudaError_t err = cudaDeviceSynchronize();