[Kernel] Add more dtype support for GGUF kernels (#14043)

Signed-off-by: SzymonOzog <szymon.ozog@aleph-alpha.com>
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
This commit is contained in:
Szymon Ożóg
2025-03-10 15:30:04 +01:00
committed by GitHub
parent b0746fae3d
commit 89cdaa83e7
6 changed files with 318 additions and 266 deletions

View File

@@ -1,8 +1,8 @@
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
template <typename scalar_t, int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const block_q_t * x = (const block_q_t *) vx;
@@ -38,7 +38,7 @@ static __device__ __forceinline__ void mul_mat_q(
threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
#pragma unroll
for (int ir = 0; ir < qr; ++ir) {
for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) {
const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;
@@ -98,7 +98,7 @@ static __device__ __forceinline__ void mul_mat_q(
if (row_dst >= nrows_dst) {
continue;
}
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE_GGUF][j/nwarps];
}
}
}
@@ -113,24 +113,25 @@ static __device__ __forceinline__ void mul_mat_q(
#define NWARPS_Q4_0 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_0;
const int mmq_y = MMQ_Y_Q4_0;
const int nwarps = NWARPS_Q4_0;
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
mul_mat_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_0;
@@ -144,11 +145,11 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -163,24 +164,25 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
#define NWARPS_Q4_1 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_1;
const int mmq_y = MMQ_Y_Q4_1;
const int nwarps = NWARPS_Q4_1;
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
mul_mat_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q4_1_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_1;
@@ -194,11 +196,11 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -213,24 +215,25 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
#define NWARPS_Q5_0 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
mul_mat_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q5_0_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_0;
@@ -244,11 +247,11 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -263,24 +266,25 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
#define NWARPS_Q5_1 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
mul_mat_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q5_1_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1;
@@ -293,11 +297,11 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -312,24 +316,25 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
#define NWARPS_Q8_0 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
mul_mat_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q8_0_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0;
@@ -342,11 +347,11 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -361,24 +366,25 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
#define NWARPS_Q2_K 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
mul_mat_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q2_K_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K;
@@ -391,11 +397,11 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -410,25 +416,26 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
#define NWARPS_Q3_K 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
mul_mat_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q3_K_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q3_K;
@@ -442,11 +449,11 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -461,24 +468,25 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
#define NWARPS_Q4_K 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
mul_mat_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q4_K_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K;
@@ -491,11 +499,11 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -510,24 +518,25 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
#define NWARPS_Q5_K 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
mul_mat_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q5_K_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_K;
@@ -541,11 +550,11 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
@@ -560,24 +569,25 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
#define NWARPS_Q6_K 4
#endif
template <bool need_check> static __global__ void
template<typename scalar_t, bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
mul_mat_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template<typename scalar_t>
static void ggml_mul_mat_q6_K_q8_1_cuda(
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K;
@@ -590,11 +600,11 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}