[FIX] Support non-zero CUDA devices in custom kernels (#1959)

This commit is contained in:
Jee Li
2024-01-03 11:09:59 +08:00
committed by GitHub
parent 4934d49274
commit 77af974b40
12 changed files with 74 additions and 30 deletions

View File

@@ -7,6 +7,7 @@
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
// Threads per block for the kernel launch (used as `dim3 threads(BLOCKWIDTH)`),
// and the divisor used to ceil-divide `width` when sizing the launch
// configuration.
#define BLOCKWIDTH 128
// NOTE(review): not referenced in the visible part of this file. Presumably the
// number of packed 32-bit rows covering BLOCKWIDTH 4-bit values
// (128 * 4 / 32 = 16) — TODO confirm against the kernel body.
#define BLOCKHEIGHT4 16
@@ -199,7 +200,7 @@ void squeezellm_gemm(
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),