Enable CUDA graph for GPTQ & SqueezeLLM (#2318)

This commit is contained in:
Woosuk Kwon
2024-01-03 09:52:29 -08:00
committed by GitHub
parent 9140561059
commit 6ef00b03a2
3 changed files with 15 additions and 13 deletions

View File

@@ -200,8 +200,10 @@ void squeezellm_gemm(
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
- vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
#else