Fix CUDA kernel index data type in vllm/csrc/quantization/fused_kernels/layernorm_utils.cuh +10 (#15159)

Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
This commit is contained in:
Lu Fang
2025-03-20 19:01:11 -07:00
committed by GitHub
parent 0cfe7d386d
commit d3ccbd6350
10 changed files with 124 additions and 124 deletions

View File

@@ -1,7 +1,7 @@
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const auto row = blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
@@ -16,7 +16,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx