fix: smem size in minimal QK test
This commit is contained in:
@@ -133,7 +133,7 @@ int main() {
|
||||
cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
int smem = 4 + 128 + TILE_SZ*2 + 4096;
|
||||
int smem = 4 + 128 + 128*16*2 + 128*16*2 + 4096;
|
||||
test_qk_minimal_kernel<<<1, 192, smem>>>(d_out, d_q, d_k, T, SK);
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
Reference in New Issue
Block a user