[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

This commit is contained in:
bnellnm
2024-06-09 16:23:30 -04:00
committed by GitHub
parent 5d7e3d0176
commit 5467ac3196
55 changed files with 833 additions and 451 deletions

View File

@@ -5,8 +5,8 @@
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
}
}; // namespace
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, float kv_scale) {
const std::string& kv_cache_dtype, double kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0);