[TPU] fix kv_cache_update kernel block size choosing logic (#21007)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Chengji Yao
2025-07-15 21:39:48 -07:00
committed by GitHub
parent c11013db8b
commit 85431bd9ad
2 changed files with 51 additions and 3 deletions


@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
    return kv_cache
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
    if dtype.is_floating_point:
        try:
            return torch.finfo(dtype).bits
        except TypeError:
            pass
    elif dtype.is_complex:
        if dtype is torch.complex32:
            return 32
        elif dtype is torch.complex64:
            return 64
        elif dtype is torch.complex128:
            return 128
    else:
        try:
            return torch.iinfo(dtype).bits
        # torch.iinfo cannot support int4, int2, bits8...
        except TypeError:
            pass
    str_dtype = str(dtype)
    # support torch.int4, torch.int5, torch.uint5...
    if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
        return int(str_dtype[-1])
    raise TypeError(f"Getting the bit width of {dtype} is not supported")
def get_dtype_packing(dtype):
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        raise ValueError(
            f"The dtype bit width must divide 32, but got bits={bits}, "
            f"dtype={dtype}")
    return 32 // bits
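The packing factor is simply how many elements of the dtype fit in one 32-bit word; a few illustrative values, assuming the dtypes are supported by the installed PyTorch:

    get_dtype_packing(torch.float32)        # 1
    get_dtype_packing(torch.bfloat16)       # 2
    get_dtype_packing(torch.float8_e4m3fn)  # 4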
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Returns the size in bytes of one page of the KV cache."""
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    num_combined_kv_heads = num_kv_heads * 2

    # NOTE: for the implicit padding in XLA
    packing = get_dtype_packing(kv_cache_dtype)
    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing

    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
    return (block_size * num_combined_kv_heads * padded_head_size *
            kv_cache_dtype_bits // 8)
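A worked example with assumed values (taking TPU_HEAD_SIZE_ALIGNMENT as 128): block_size=32, num_kv_heads=8, head_size=96 and a bfloat16 cache give padded_head_size=128 and num_combined_kv_heads=16 (already a multiple of the packing factor 2), so one page is 32 * 16 * 128 * 16 // 8 = 131072 bytes, roughly 2.7x the 49152 bytes the unpadded formula would report.

    # assumed inputs; TPU_HEAD_SIZE_ALIGNMENT == 128
    get_page_size_bytes(block_size=32, num_kv_heads=8, head_size=96,
                        kv_cache_dtype=torch.bfloat16)  # 131072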