[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)
This commit is contained in:
@@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
|
||||
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
|
||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
|
||||
switch (type) {
|
||||
case 2:
|
||||
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
||||
|
||||
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
|
||||
}
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
int8_t type, int64_t m, int64_t n) {
|
||||
int64_t type, int64_t m, int64_t n) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int8_t type, int64_t row) {
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int8_t type, int64_t row) {
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
int batch = X.sizes()[0];
|
||||
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
break;
|
||||
}
|
||||
return Y;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user