Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com> Co-authored-by: Amir Balwel <amoooori04@gmail.com> Co-authored-by: root <kuanfu.liu@akirakan.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com> Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -28,8 +29,8 @@ void swap_blocks(
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
void *src_ptr = src.data_ptr();
|
||||
void *dst_ptr = dst.data_ptr();
|
||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
|
||||
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
|
||||
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
|
||||
src_key_indices[j] = src_key_idx;
|
||||
src_value_indices[j] = src_value_idx;
|
||||
|
||||
keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
|
||||
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
||||
Reference in New Issue
Block a user