[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)
This commit is contained in:
@@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dequantize.cuh"
|
||||
@@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64)
|
||||
|
||||
torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros, int split_k_iters, int thx,
|
||||
int thy) {
|
||||
torch::Tensor _zeros, int64_t split_k_iters,
|
||||
int64_t thx, int64_t thy) {
|
||||
int in_c = _kernel.size(0);
|
||||
int qout_c = _kernel.size(1);
|
||||
int out_c = qout_c * 8;
|
||||
@@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
|
||||
|
||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||
int split_k_iters) {
|
||||
int64_t split_k_iters) {
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
Reference in New Issue
Block a user