[Core][Optimization] change python dict to pytorch tensor (#4607)
This commit is contained in:
@@ -8,16 +8,16 @@ template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(
|
||||
std::vector<torch::Tensor> &key_caches,
|
||||
std::vector<torch::Tensor> &value_caches,
|
||||
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
|
||||
const torch::Tensor& mapping_pairs,
|
||||
const int element_num_per_block, const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size();
|
||||
const size_t pair_num = mapping_pairs.size(0);
|
||||
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int layer = 0; layer < layer_num; ++layer) {
|
||||
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
|
||||
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
|
||||
int64_t target_offset =
|
||||
element_num_per_block * mapping_pairs[pair].second;
|
||||
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
|
||||
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t *source_ptr = key_cache_ptr + source_offset;
|
||||
scalar_t *target_ptr = key_cache_ptr + target_offset;
|
||||
@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
|
||||
|
||||
void copy_blocks(std::vector<torch::Tensor> &key_caches,
|
||||
std::vector<torch::Tensor> &value_caches,
|
||||
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
|
||||
torch::Tensor& block_mapping) {
|
||||
int num_layers = key_caches.size();
|
||||
TORCH_CHECK(num_layers == value_caches.size());
|
||||
if (num_layers == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
|
||||
mapping_pairs.reserve(block_mapping.size());
|
||||
for (const auto &pair : block_mapping) {
|
||||
for (const auto &dst : pair.second) {
|
||||
mapping_pairs.emplace_back(pair.first, dst);
|
||||
}
|
||||
}
|
||||
|
||||
const int element_num_per_block = key_caches[0][0].numel();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user