fix: SMEM size calculation — TILE_SZ is in BF16 elements, need *sizeof(bf16_t) for bytes
This commit is contained in:
@@ -193,7 +193,7 @@ int main() {
|
||||
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
size_t smem = 4 + 128 + TILE_SZ + TILE_SZ + TILE_SZ + 16 + 8 + 128*4*2 + 256;
|
||||
size_t smem = 4 + 128 + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + 16 + 8 + 128*sizeof(float)*2 + 256;
|
||||
cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
|
||||
test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE);
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
|
||||
Reference in New Issue
Block a user