[Hardware][TPU] Enable ragged paged attention kernel and resolve recompilation issue (#14310)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao
Date: 2025-03-06 15:31:05 -08:00
Committed by: GitHub
Parent: 04222984f8
Commit: 0578e5a462

3 changed files with 58 additions and 66 deletions


@@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import CommonAttentionState

 # These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 16
+NUM_QUERIES_PER_BLOCK = 32
 NUM_KV_PAGES_PER_BLOCK = 128
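These two constants size the tiles the Pallas kernel stages through vmem, which is why the patch also introduces a vmem_limit_bytes below. A back-of-envelope sketch of that interaction (my own illustration; kv_block_vmem_bytes and the example dimensions are hypothetical, not from this patch):

    def kv_block_vmem_bytes(num_kv_pages_per_block: int,
                            page_size: int,
                            num_kv_heads: int,
                            head_size: int,
                            bytes_per_elem: int = 2) -> int:
        # Rough vmem footprint of the KV pages one kernel block touches;
        # K and V pages are both resident, hence the factor of 2.
        return (2 * num_kv_pages_per_block * page_size * num_kv_heads
                * head_size * bytes_per_elem)

    # 128 pages x 16 tokens/page x 8 KV heads x head_size 128 in bf16 ~= 8 MiB,
    # i.e. half of the 16MB TPU v4 vmem budget noted later in this diff.
    print(kv_block_vmem_bytes(128, 16, 8, 128) // (1024 * 1024), "MiB")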
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
-        return (num_kv_heads, num_blocks, block_size, head_size)
+        return (num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
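The swapped layout is what later lets write_to_kv_cache flatten the leading dims into one row per token slot. A minimal sketch of that property (illustrative sizes, plain PyTorch, not from the patch):

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
    cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    # Flattening blocks x block_size yields one row per physical token slot,
    # so a token-level slot index addresses all KV heads of a token at once.
    flat = cache.flatten(0, 1)   # [num_blocks * block_size, num_kv_heads, head_size]
    slot = 2 * block_size + 5    # slot = block_number * block_size + block_offset
    flat[slot] = 1.0             # flatten() of a contiguous tensor is a view
    assert cache[2, 5].eq(1.0).all()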
@@ -115,6 +115,17 @@ class PallasAttentionBackendImpl(AttentionImpl):
                 "are not implemented for "
                 "PallasAttentionBackendImpl")
+
+        tpu_version = torch_xla.tpu.version()
+        if tpu_version < 4:
+            raise NotImplementedError("TPU version must be 4 or higher.")
+        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
+        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
+        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
+        if tpu_version == 4:
+            self.vmem_limit_bytes = 16 * 1024 * 1024
+        else:
+            self.vmem_limit_bytes = 64 * 1024 * 1024

     def forward(
         self,
         layer: AttentionLayer,
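The same selection, restated as a standalone helper for readability (select_vmem_limit_bytes is a hypothetical name, not part of the patch):

    def select_vmem_limit_bytes(tpu_version: int) -> int:
        if tpu_version < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")
        # Per the NOTE above, TPU v4 has 16MB of vmem; later generations
        # are given a larger 64MB budget here.
        return 16 * 1024 * 1024 if tpu_version == 4 else 64 * 1024 * 1024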
@@ -131,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
-                        [num_kv_heads, num_blocks, block_size, head_size])
+            kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
+                        [num_blocks, block_size, num_kv_heads, head_size])
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -154,10 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
-        query = query * self.scale
-        # use_kernel switches between using kernel or reference implementation
-        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
-        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -168,8 +175,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=use_kernel,
-        )
+            vmem_limit_bytes=self.vmem_limit_bytes,
+            use_kernel=True,
+            sm_scale=self.scale)

         return output.reshape(num_tokens, hidden_size)
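Dropping the explicit `query = query * self.scale` works because scaling the queries before the matmul and scaling the logits inside the kernel (sm_scale) are mathematically the same. A quick plain-PyTorch check of that identity (illustrative shapes; this is not the Pallas kernel):

    import torch

    q = torch.randn(4, 8)
    k = torch.randn(6, 8)
    scale = 8 ** -0.5

    pre_scaled = torch.softmax((q * scale) @ k.T, dim=-1)   # old: scale queries first
    sm_scaled = torch.softmax((q @ k.T) * scale, dim=-1)    # new: scale the logits
    assert torch.allclose(pre_scaled, sm_scaled, atol=1e-6)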
@@ -186,16 +194,15 @@ def write_to_kv_cache(
     Args:
         key: shape = [num_tokens, num_kv_heads, head_size]
         value: shape = [num_tokens, num_kv_heads, head_size]
-        k_cache = [num_kv_heads, num_blocks, block_size, head_size]
-        v_cache = [num_kv_heads, num_blocks, block_size, head_size]
+        k_cache = [num_blocks, block_size, num_kv_heads, head_size]
+        v_cache = [num_blocks, block_size, num_kv_heads, head_size]
     """
     torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
     torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

-    key = key.flatten(0, 1)
-    value = value.flatten(0, 1)
-    key_cache = key_cache.flatten(0, 2)
-    value_cache = value_cache.flatten(0, 2)
+    key_cache = key_cache.flatten(0, 1)
+    value_cache = value_cache.flatten(0, 1)
     slot_mapping = slot_mapping.flatten()
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
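End-to-end, the write path under the new layout looks like the following usage sketch (a plain-PyTorch stand-in with made-up sizes; it mirrors the flatten/index_copy_ arithmetic above but skips the XLA buffer-donor calls):

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 8, 16, 2, 4
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.zeros_like(key_cache)

    num_tokens = 3
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    # Token-granular slots: block_number * block_size + block_offset.
    slot_mapping = torch.tensor([5 * block_size + 0,
                                 5 * block_size + 1,
                                 5 * block_size + 2])

    key_cache.flatten(0, 1).index_copy_(0, slot_mapping, key)
    value_cache.flatten(0, 1).index_copy_(0, slot_mapping, value)
    assert torch.equal(key_cache[5, :3], key)   # landed in block 5, offsets 0..2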