[TPU] kv cache update kernel supports dynamic grid (#20235)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao
2025-07-01 23:33:37 -07:00
committed by GitHub
parent b205e8467d
commit 7da296be04
4 changed files with 42 additions and 17 deletions

View File

@@ -111,6 +111,7 @@ class PallasMetadata:
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
num_kv_update_slices: torch.Tensor
num_slices_per_kv_cache_update_block: int
@@ -219,7 +220,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key, value, kv_cache, slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block)
attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices)
output = torch.ops.xla.ragged_paged_attention(
query,
@@ -252,6 +254,7 @@ def write_to_kv_cache(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.
@@ -271,7 +274,7 @@ def write_to_kv_cache(
kv_cache = kv_cache.flatten(0, 1)
new_kv_cache = torch.ops.xla.kv_cache_update_op(
kv, slot_mapping, kv_cache, page_size,
kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
num_slices_per_kv_cache_update_block)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache.copy_(new_kv_cache)
@@ -279,32 +282,39 @@ def write_to_kv_cache(
@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
new_kv_cache = xb.call_jax(
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
"int page_size, int num_slices_per_block) -> Tensor", )
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
"-> Tensor", )
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
page_size, num_slices_per_block)
num_kv_update_slices, page_size,
num_slices_per_block)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor, page_size: int,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache