[Bugfix] Add kv_scale input parameter to CPU backend (#3840)

This commit is contained in:
Woosuk Kwon
2024-04-03 21:33:08 -07:00
committed by GitHub
parent 537ee25f43
commit 498eb5cfa3
4 changed files with 12 additions and 5 deletions


@@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
     CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
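
The new trailing kv_scale argument exists for interface parity with the GPU ops (where it is used for scaled KV-cache dtypes); the CPU path does not implement that scaling, so TORCH_CHECK rejects any value other than the no-op 1.0f. A minimal standalone sketch of the guard behavior, assuming libtorch; the stub name and error message are hypothetical, not vLLM's actual op:

#include <torch/torch.h>
#include <iostream>

// Hypothetical stand-in for the patched op: it accepts kv_scale, but only
// the neutral value 1.0f passes the guard, mirroring the TORCH_CHECK above.
void paged_attention_cpu_stub(float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f, "CPU backend does not support kv_scale != 1.0");
  std::cout << "kv_scale accepted: " << kv_scale << "\n";
}

int main() {
  paged_attention_cpu_stub(1.0f);    // passes the guard
  try {
    paged_attention_cpu_stub(0.5f);  // TORCH_CHECK throws c10::Error
  } catch (const c10::Error &e) {
    std::cout << "rejected: " << e.what_without_backtrace() << "\n";
  }
  return 0;
}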
@@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
     CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
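
Both paged_attention_v1 and paged_attention_v2 receive the identical signature change, so every CPU call site must now supply the extra trailing argument. A sketch of how an existing caller stays correct by forwarding the neutral scale explicitly; the function names are hypothetical and the parameter list is abbreviated relative to the real ops:

#include <torch/torch.h>
#include <string>

// Hypothetical stand-in for the updated op signature (trailing kv_scale).
void paged_attention_v1_stub(const torch::Tensor &query,
                             const std::string &kv_cache_dtype,
                             float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);  // CPU path: KV-cache scaling unsupported
  (void)query;
  (void)kv_cache_dtype;
}

// A caller written against the old signature stays source-compatible by
// passing the no-op scale of 1.0f, which is the only value the guard allows.
void legacy_call_site(const torch::Tensor &query,
                      const std::string &kv_cache_dtype) {
  paged_attention_v1_stub(query, kv_cache_dtype, /*kv_scale=*/1.0f);
}

int main() {
  legacy_call_site(torch::zeros({1, 8}), "auto");
  return 0;
}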