[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin
2024-05-22 03:18:41 -04:00
committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

View File

@@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) {
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
const half2* __restrict__ vec,
const half2* __restrict__ vec,
#else
const __half2* __restrict__ vec,
const __half2* __restrict__ vec,
#endif
const int* __restrict__ mat,
const int* __restrict__ mat,
#ifndef USE_ROCM
half2* __restrict__ mul,
half2* __restrict__ mul,
#else
float2* __restrict__ mul,
float2* __restrict__ mul,
#endif
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {
const __half* __restrict__ lookup_table, int height, int width, int batch,
int vec_height) {
const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
#ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2];
@@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel(
unsigned int tmp1;
unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b){
for (int b = 0; b < batch; ++b) {
i = width * row + col;
res = __int2half_rd(0);
k = 0;
__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
blockvec[threadIdx.x] =
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
threadIdx.x];
__syncthreads();
while (k < blockwidth2) {
@@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
#else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
res);
#endif
i += width;
@@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel(
}
}
} // namespace squeezellm
} // namespace vllm
} // namespace squeezellm
} // namespace vllm
// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table) {
int height = mat.size(0);
int width = mat.size(1);
int batch = vec.size(0);
int vec_height = vec.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
(half2*)vec.data<at::Half>(),
#else
(__half2*) vec.data_ptr<at::Half>(),
(__half2*)vec.data_ptr<at::Half>(),
#endif
mat.data_ptr<int>(),
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data<at::Half>(),
#else
(float2*) mul.data_ptr<float>(),
(__half*) lookup_table.data_ptr<at::Half>(),
(float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#endif
height, width, batch, vec_height
);
height, width, batch, vec_height);
}
#undef BLOCKWIDTH