[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

This commit is contained in:
bnellnm
2024-06-09 16:23:30 -04:00
committed by GitHub
parent 5d7e3d0176
commit 5467ac3196
55 changed files with 833 additions and 451 deletions

View File

@@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
-#include <torch/extension.h>
+#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
@@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64)
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
-torch::Tensor _zeros, int split_k_iters, int thx,
-int thy) {
+torch::Tensor _zeros, int64_t split_k_iters,
+int64_t thx, int64_t thy) {
int in_c = _kernel.size(0);
int qout_c = _kernel.size(1);
int out_c = qout_c * 8;
@@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
-int split_k_iters) {
+int64_t split_k_iters) {
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));