[TPU] Optimize kv cache update kernel (#20415)

Signed-off-by: Yifei Teng <tengyifei88@gmail.com>
This commit is contained in:
Yifei Teng
2025-07-15 03:56:43 -07:00
committed by GitHub
parent 33d560001e
commit c586b55667
3 changed files with 63 additions and 16 deletions

View File

@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype) -> int:
"""Returns the size in bytes of one page of the KV cache."""
return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize