[Kernel][CPU] CPU MLA (#14744)
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
@@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void concat_and_cache_mla_cpu_impl(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int num_tokens, //
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size //
|
||||
) {
|
||||
#pragma omp parallel for
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src,
|
||||
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
|
||||
int size, int offset) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx =
|
||||
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
@@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
});
|
||||
}
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
TORCH_CHECK(kv_cache_dtype != "fp8");
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
|
||||
concat_and_cache_mla_cpu_impl<scalar_t>(
|
||||
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
|
||||
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
|
||||
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
|
||||
kv_lora_rank, pe_dim, block_size);
|
||||
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping) {
|
||||
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
||||
|
||||
Reference in New Issue
Block a user