[TPU][Bugfix] fix kv cache padding (#20048)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@@ -48,13 +48,7 @@ class PallasAttentionBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         padded_head_size = cdiv(
             head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-        num_blocks = num_blocks * head_size // padded_head_size
-        if padded_head_size != head_size:
-            logger.warning_once(
-                "head size is padded to %d, and num_blocks is adjusted to %d"
-                " accordingly", padded_head_size, num_blocks)
-            head_size = padded_head_size
-        return (num_blocks, block_size, num_kv_heads * 2, head_size)
+        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
 
     @staticmethod
     def swap_blocks(
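For context, the removed code rescaled num_blocks by head_size // padded_head_size and overwrote head_size; the fix leaves num_blocks untouched and simply returns the padded head size. Below is a minimal, self-contained sketch of the resulting shape math, not the vLLM source itself; cdiv is assumed to be ceiling division as in vllm.utils, and the alignment value of 128 is an assumption for illustration:

    # Sketch of the kv cache shape computation after this fix (assumptions noted).
    def cdiv(a: int, b: int) -> int:
        # Ceiling division, assumed to match vllm.utils.cdiv.
        return -(-a // b)

    TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed alignment constant for illustration

    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int, head_size: int) -> tuple[int, ...]:
        # Pad head_size up to the TPU alignment; num_blocks is no longer
        # scaled down to compensate for the padding.
        padded_head_size = cdiv(
            head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
        # K and V are stored together, hence num_kv_heads * 2.
        return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)

    # Example with hypothetical values: head_size=96 pads to 128
    # while num_blocks stays 1024.
    assert get_kv_cache_shape(1024, 16, 8, 96) == (1024, 16, 16, 128)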