[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)

This commit is contained in:
bnellnm
2024-08-16 17:00:11 -04:00
committed by GitHub
parent 7759ae958f
commit 37fd47e780
7 changed files with 39 additions and 53 deletions

View File

@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
}
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int8_t type, int64_t m, int64_t n) {
int64_t type, int64_t m, int64_t n) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int8_t type, int64_t row) {
int64_t type, int64_t row) {
int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int8_t type, int64_t row) {
int64_t type, int64_t row) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
int batch = X.sizes()[0];
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
break;
}
return Y;
}
}