[TPU] add kv cache update kernel (#19928)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@@ -5,8 +5,12 @@ from dataclasses import dataclass
 from typing import Any, Optional
 
 import torch
-# Required to register custom ops.
+import torch_xla.core.xla_builder as xb
 import torch_xla.experimental.custom_kernel  # noqa: F401
+# Required to register custom ops.
+from torch.library import impl
+from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.custom_kernel import XLA_LIB
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -107,6 +111,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
+    num_slices_per_kv_cache_update_block: int
 
 
 class PallasAttentionBackendImpl(AttentionImpl):
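The new metadata field sizes the grid of the Pallas update kernel below: slot updates are grouped into contiguous slices, and each kernel block handles a fixed number of them. A minimal sketch of the implied grid arithmetic (the helper name is hypothetical, not part of this diff):

    # Hypothetical helper illustrating the grid arithmetic implied by the
    # field: one kernel block per group of num_slices_per_block slices.
    def num_update_kernel_blocks(num_slices: int,
                                 num_slices_per_block: int) -> int:
        return -(num_slices // -num_slices_per_block)  # ceiling division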
@@ -212,7 +217,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, kv_cache, slot_mapping)
+            write_to_kv_cache(
+                key, value, kv_cache, slot_mapping,
+                attn_metadata.num_slices_per_kv_cache_update_block)
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
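For context on `slot_mapping` (unchanged by this hunk): vLLM's paged KV cache addresses each token's slot as block_number * block_size + block_offset into the flattened cache, which is what the old index_copy_ path consumed. A small illustration with made-up values:

    import torch

    # Hypothetical values: three tokens landing in blocks 3, 3 and 7.
    block_size = 16
    block_numbers = torch.tensor([3, 3, 7])
    block_offsets = torch.tensor([14, 15, 0])
    slot_mapping = block_numbers * block_size + block_offsets
    # tensor([ 62,  63, 112]) -- flat indices into kv_cache.flatten(0, 1)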
@@ -244,6 +251,7 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    num_slices_per_kv_cache_update_block: int,
 ) -> None:
     """ Write the key and values to the KV cache.
 
@@ -251,9 +259,9 @@ def write_to_kv_cache(
         key: shape = [num_tokens, num_kv_heads * head_size]
         value: shape = [num_tokens, num_kv_heads * head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
-
+        num_slices_per_kv_cache_update_block: int
     """
-    _, _, num_combined_kv_heads, head_size = kv_cache.shape
+    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
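The cdiv line pads the head dimension up to the TPU lane width. A worked example, assuming TPU_HEAD_SIZE_ALIGNMENT is 128 (the constant is defined elsewhere in this file, not in the hunk):

    def cdiv(a: int, b: int) -> int:
        return -(a // -b)  # ceiling division

    TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed value for illustration
    head_size = 96
    padded = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    assert padded == 128  # 96 rounds up to one full 128-wide lane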
@@ -262,4 +270,41 @@ def write_to_kv_cache(
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
 
     kv_cache = kv_cache.flatten(0, 1)
-    kv_cache.index_copy_(0, slot_mapping, kv)
+    new_kv_cache = torch.ops.xla.kv_cache_update_op(
+        kv, slot_mapping, kv_cache, page_size,
+        num_slices_per_kv_cache_update_block)
+    # NOTE: the in-place copy will be optimized away by XLA compiler.
+    kv_cache.copy_(new_kv_cache)
+
+
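The replacement keeps the update functional: the op returns a fresh cache, and the copy back into the donated buffer lets XLA alias input and output, so no extra cache-sized allocation survives compilation. The pattern in isolation (the update function here is a hypothetical stand-in):

    import torch

    def functional_update(t: torch.Tensor) -> torch.Tensor:
        # stand-in for any op that returns an updated copy of its input
        return t + 1

    buf = torch.zeros(4)  # on TPU this would be an XLA tensor
    # torch.ops.xla.dynamo_set_buffer_donor_(buf, True)  # XLA-only: mark donor
    new_buf = functional_update(buf)
    buf.copy_(new_buf)  # with donation, XLA aliases buf/new_buf and elides this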
+@requires_jax
+def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                            kv_cache: torch.Tensor, page_size: int,
+                            num_slices_per_block: int):
+    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
+        "page_size": page_size,
+        "num_slices_per_block": num_slices_per_block
+    })
+    return new_kv_cache
+
+
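xb.call_jax traces a JAX function into the XLA graph, taking positional args as a tuple and keyword args as a dict, which is exactly the calling convention used above. A trivial usage sketch (the scale function and x are assumptions; x must live on a torch_xla device):

    import torch
    import jax.numpy as jnp

    def scale(x, factor):
        # traced by JAX, not eager PyTorch
        return x * jnp.asarray(factor, dtype=x.dtype)

    x = torch.ones(4, device='xla')  # assumes a torch_xla device is available
    y = xb.call_jax(scale, (x,), {"factor": 2.0})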
+XLA_LIB.define(
+    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
+    "int page_size, int num_slices_per_block) -> Tensor", )
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "XLA")
+def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                           kv_cache: torch.Tensor, page_size: int,
+                           num_slices_per_block: int) -> torch.Tensor:
+    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                           page_size, num_slices_per_block)
+    return new_kv_cache
+
+
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
return kv_cache
|
||||
|
||||
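The CompositeExplicitAutograd registration is a shape-preserving stub: on non-XLA backends (e.g. while tracing or testing on CPU) the op simply echoes kv_cache, keeping output shape and dtype consistent with the XLA path. A hedged call-site sketch; all shapes are assumptions based on the docstring above, and the real slot_mapping layout is whatever the Pallas kernel expects:

    import torch

    # Placeholder inputs: 4 tokens, 2 combined KV heads, head_size 128,
    # cache flattened to [num_blocks * page_size, heads, head_size].
    kv = torch.zeros(4, 2, 128)
    slot_mapping = torch.arange(4)
    flat_kv_cache = torch.zeros(8 * 16, 2, 128)

    # On an XLA device this dispatches to kv_cache_update_op_xla (the Pallas
    # kernel); on any other backend it hits the stub above and returns
    # flat_kv_cache unchanged.
    out = torch.ops.xla.kv_cache_update_op(kv, slot_mapping, flat_kv_cache,
                                           16, 8)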