[Hardware][TPU] Enable ragged paged attention kernel and resolve recompilation issue (#14310)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
# These are the 2 tunable parameters of the paged attention Pallas kernel.
|
||||
NUM_QUERIES_PER_BLOCK = 16
|
||||
NUM_QUERIES_PER_BLOCK = 32
|
||||
NUM_KV_PAGES_PER_BLOCK = 128
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@@ -115,6 +115,17 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
tpu_version = torch_xla.tpu.version()
|
||||
if tpu_version < 4:
|
||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
|
||||
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
|
||||
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
|
||||
if tpu_version == 4:
|
||||
self.vmem_limit_bytes = 16 * 1024 * 1024
|
||||
else:
|
||||
self.vmem_limit_bytes = 64 * 1024 * 1024
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
@@ -131,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
|
||||
[num_kv_heads, num_blocks, block_size, head_size])
|
||||
kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
|
||||
[num_blocks, block_size, num_kv_heads, head_size])
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@@ -154,10 +165,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
query = query * self.scale
|
||||
# use_kernel switches between using kernel or reference implementation
|
||||
# (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
|
||||
use_kernel = False
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
@@ -168,8 +175,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
attn_metadata.num_seqs,
|
||||
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
|
||||
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
|
||||
use_kernel=use_kernel,
|
||||
)
|
||||
vmem_limit_bytes=self.vmem_limit_bytes,
|
||||
use_kernel=True,
|
||||
sm_scale=self.scale)
|
||||
|
||||
return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
@@ -186,16 +194,15 @@ def write_to_kv_cache(
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
k_cache = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
v_cache = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
k_cache = [num_blocks, block_size, num_kv_heads, head_size]
|
||||
v_cache = [num_blocks, block_size, num_kv_heads, head_size]
|
||||
|
||||
"""
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
|
||||
|
||||
key = key.flatten(0, 1)
|
||||
value = value.flatten(0, 1)
|
||||
key_cache = key_cache.flatten(0, 2)
|
||||
value_cache = value_cache.flatten(0, 2)
|
||||
key_cache = key_cache.flatten(0, 1)
|
||||
value_cache = value_cache.flatten(0, 1)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
key_cache.index_copy_(0, slot_mapping, key)
|
||||
value_cache.index_copy_(0, slot_mapping, value)
|
||||
|
||||
Reference in New Issue
Block a user