Fix CUDA kernel index data type in vllm/csrc/quantization/fused_kernels/layernorm_utils.cuh +10 (#15159)

Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
This commit is contained in:
Lu Fang
2025-03-20 19:01:11 -07:00
committed by GitHub
parent 0cfe7d386d
commit d3ccbd6350
10 changed files with 124 additions and 124 deletions

View File

@@ -199,12 +199,12 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
auto t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -337,12 +337,12 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
auto t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -458,12 +458,12 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
auto t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -586,12 +586,12 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
auto t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -765,14 +765,14 @@ __global__ void reconstruct_exllama_8bit_kernel(
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
auto t = threadIdx.x;
if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -862,14 +862,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
auto t = threadIdx.x;
if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -967,14 +967,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
auto t = threadIdx.x;
if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1065,14 +1065,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
auto t = threadIdx.x;
if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1181,11 +1181,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
int zero_width = width / 8;
int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
int h = BLOCK_KN_SIZE * blockIdx.z / 8;
auto h = BLOCK_KN_SIZE * blockIdx.z / 8;
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
@@ -1197,8 +1197,8 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
}
__shared__ half2 deq2[256][8];
int val = threadIdx.x / 8;
int off = threadIdx.x % 8;
auto val = threadIdx.x / 8;
auto off = threadIdx.x % 8;
for (; val < 256; val += BLOCK_KN_SIZE / 8) {
deq2[val][off] =
__halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
@@ -1280,11 +1280,11 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
int h = BLOCK_KN_SIZE * blockIdx.z / 4;
auto h = BLOCK_KN_SIZE * blockIdx.z / 4;
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
@@ -1393,8 +1393,8 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
half* __restrict__ out) {
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 32 / bit;
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto row = blockIdx.y * 32 / bit;
if (column >= width) return;
// Views
@@ -1425,8 +1425,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
const int height, const int width, const int group,
half* __restrict__ out) {
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 32;
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto row = blockIdx.y * 32;
if (column >= width) return;
// Views
@@ -1542,7 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x;
auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1555,7 +1555,7 @@ __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x;
auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1568,7 +1568,7 @@ __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,
__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x;
auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1581,7 +1581,7 @@ __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,
__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x;
auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1599,9 +1599,9 @@ __global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w,
const uint64_t* w2 = (uint64_t*)w;
uint64_t* w_new2 = (uint64_t*)w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
auto w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
@@ -1630,9 +1630,9 @@ __global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w,
const uint64_t* w2 = (uint64_t*)w;
uint64_t* w_new2 = (uint64_t*)w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
auto w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 4;
uint64_t dst = 0;
@@ -1658,10 +1658,10 @@ __global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width) {
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
auto w_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w_column >= w_width) return;
int w_new_row = blockIdx.y * 3;
int q_perm_idx = blockIdx.y << 5;
auto w_new_row = blockIdx.y * 3;
auto q_perm_idx = blockIdx.y << 5;
uint32_t dst[3] = {0, 0, 0};
#pragma unroll
@@ -1744,9 +1744,9 @@ __global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w,
const uint64_t* w2 = (uint64_t*)w;
uint64_t* w_new2 = (uint64_t*)w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
auto w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 2;
uint64_t dst = 0;