[Kernel] fix types used in aqlm and ggml kernels to support dynamo (#7596)
This commit is contained in:
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
|
||||
}
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
int8_t type, int64_t m, int64_t n) {
|
||||
int64_t type, int64_t m, int64_t n) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int8_t type, int64_t row) {
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
torch::Tensor X, // input
|
||||
int8_t type, int64_t row) {
|
||||
int64_t type, int64_t row) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
int batch = X.sizes()[0];
|
||||
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
break;
|
||||
}
|
||||
return Y;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user