[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)
This commit is contained in:
@@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
const half2* __restrict__ vec,
|
||||
const half2* __restrict__ vec,
|
||||
#else
|
||||
const __half2* __restrict__ vec,
|
||||
const __half2* __restrict__ vec,
|
||||
#endif
|
||||
const int* __restrict__ mat,
|
||||
const int* __restrict__ mat,
|
||||
#ifndef USE_ROCM
|
||||
half2* __restrict__ mul,
|
||||
half2* __restrict__ mul,
|
||||
#else
|
||||
float2* __restrict__ mul,
|
||||
float2* __restrict__ mul,
|
||||
#endif
|
||||
const __half* __restrict__ lookup_table,
|
||||
int height,
|
||||
int width,
|
||||
int batch,
|
||||
int vec_height
|
||||
) {
|
||||
const __half* __restrict__ lookup_table, int height, int width, int batch,
|
||||
int vec_height) {
|
||||
|
||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
@@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
|
||||
unsigned int tmp1;
|
||||
unsigned int lut_index1, lut_index2;
|
||||
|
||||
for (int b = 0; b < batch; ++b){
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
i = width * row + col;
|
||||
res = __int2half_rd(0);
|
||||
k = 0;
|
||||
|
||||
__syncthreads();
|
||||
if (threadIdx.x < blockwidth2)
|
||||
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
|
||||
blockvec[threadIdx.x] =
|
||||
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
|
||||
threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
while (k < blockwidth2) {
|
||||
@@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
#else
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
|
||||
res);
|
||||
#endif
|
||||
|
||||
i += width;
|
||||
@@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace squeezellm
|
||||
} // namespace vllm
|
||||
} // namespace squeezellm
|
||||
} // namespace vllm
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table
|
||||
) {
|
||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor lookup_table) {
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
|
||||
);
|
||||
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
||||
#ifndef USE_ROCM
|
||||
(half2*) vec.data<at::Half>(),
|
||||
(half2*)vec.data<at::Half>(),
|
||||
#else
|
||||
(__half2*) vec.data_ptr<at::Half>(),
|
||||
(__half2*)vec.data_ptr<at::Half>(),
|
||||
#endif
|
||||
mat.data_ptr<int>(),
|
||||
mat.data_ptr<int>(),
|
||||
#ifndef USE_ROCM
|
||||
(half2*) mul.data<at::Half>(),
|
||||
(__half*) lookup_table.data<at::Half>(),
|
||||
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
|
||||
#else
|
||||
(float2*) mul.data_ptr<float>(),
|
||||
(__half*) lookup_table.data_ptr<at::Half>(),
|
||||
(float2*)mul.data_ptr<float>(),
|
||||
(__half*)lookup_table.data_ptr<at::Half>(),
|
||||
#endif
|
||||
height, width, batch, vec_height
|
||||
);
|
||||
height, width, batch, vec_height);
|
||||
}
|
||||
|
||||
#undef BLOCKWIDTH
|
||||
|
||||
Reference in New Issue
Block a user