[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)
This commit is contained in:
@@ -5,8 +5,8 @@
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& mapping_pairs,
|
||||
const int element_num_per_block,
|
||||
const int layer_num) {
|
||||
@@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl(
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& block_mapping) {
|
||||
unsigned num_layers = key_caches.size();
|
||||
TORCH_CHECK(num_layers == value_caches.size());
|
||||
@@ -104,7 +107,7 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, float kv_scale) {
|
||||
const std::string& kv_cache_dtype, double kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
|
||||
int num_tokens = key.size(0);
|
||||
|
||||
Reference in New Issue
Block a user