fix: SMEM size calculation — TILE_SZ is in BF16 elements, need *sizeof(bf16_t) for bytes

This commit is contained in:
2026-05-29 19:30:50 +00:00
parent aac1b25442
commit 7a8ba8eeb6

View File

@@ -193,7 +193,7 @@ int main() {
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
size_t smem = 4 + 128 + TILE_SZ + TILE_SZ + TILE_SZ + 16 + 8 + 128*4*2 + 256;
size_t smem = 4 + 128 + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + 16 + 8 + 128*sizeof(float)*2 + 256;
cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE);
cudaError_t err = cudaDeviceSynchronize();