fix: smem size in minimal QK test

This commit is contained in:
2026-05-29 18:37:38 +00:00
parent ce89fe9170
commit 39aef1284f

View File

@@ -133,7 +133,7 @@ int main() {
cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
int smem = 4 + 128 + TILE_SZ*2 + 4096;
int smem = 4 + 128 + 128*16*2 + 128*16*2 + 4096;
test_qk_minimal_kernel<<<1, 192, smem>>>(d_out, d_q, d_k, T, SK);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }