[V1][TPU] Change kv cache shape. (#15145)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
iefgnoix
2025-03-19 12:16:42 -07:00
committed by GitHub
parent 8310e0b59b
commit b0e96aaebb
2 changed files with 13 additions and 16 deletions

View File

@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
return (num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def swap_blocks(
@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
[num_blocks, block_size, num_kv_heads, head_size])
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
[num_blocks, block_size, num_kv_heads * head_size])
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(num_tokens, self.num_kv_heads, self.head_size)
value = value.view(num_tokens, self.num_kv_heads, self.head_size)
key_cache, value_cache = kv_cache
if kv_cache[0].numel() > 0:
@@ -192,10 +190,10 @@ def write_to_kv_cache(
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
k_cache = [num_blocks, block_size, num_kv_heads, head_size]
v_cache = [num_blocks, block_size, num_kv_heads, head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
"""
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
@@ -203,6 +201,5 @@ def write_to_kv_cache(
key_cache = key_cache.flatten(0, 1)
value_cache = value_cache.flatten(0, 1)
slot_mapping = slot_mapping.flatten()
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)