[TPU] Optimize kv cache update kernel (#20415)
Signed-off-by: Yifei Teng <tengyifei88@gmail.com>
This commit is contained in:
@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
return kv_cache
|
||||
|
||||
|
||||
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
|
||||
kv_cache_dtype: torch.dtype) -> int:
|
||||
"""Returns the size in bytes of one page of the KV cache."""
|
||||
return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
|
||||
|
||||
Reference in New Issue
Block a user