Allocate kv_cache with stride order (#16605)

Signed-off-by: shuw <shuw@nvidia.com>
This commit is contained in:
Shu Wang
2025-04-26 00:03:31 -05:00
committed by GitHub
parent b278911229
commit 9e96f56efb
6 changed files with 119 additions and 50 deletions

View File

@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const float* k_scale, const float* v_scale) {
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size +
head_idx * head_size + head_offset;
block_offset * page_stride +
head_idx * head_stride + head_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
@@ -396,16 +397,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_flash(
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
int head_size = key.size(2);
int block_size = key_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = key_cache.stride(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int64_t block_stride = key_cache.stride(0);
int64_t page_stride = key_cache.stride(1);
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
dim3 grid(num_tokens);