[TPU] support attention head dim smaller than 128 (#19620)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Chengji Yao
2025-06-15 23:40:53 -07:00
committed by GitHub
parent b692e9cd07
commit a77aea59fd
2 changed files with 65 additions and 7 deletions

View File

@@ -17,6 +17,9 @@ from vllm.utils import cdiv, next_power_of_2
logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128; head sizes that are not
# are zero-padded up to the next multiple of this alignment.
TPU_HEAD_SIZE_ALIGNMENT = 128
class PallasAttentionBackend(AttentionBackend):
@@ -43,6 +46,14 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
padded_head_size = cdiv(
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
num_blocks = num_blocks * head_size // padded_head_size
if padded_head_size != head_size:
logger.warning_once(
"head size is padded to %d, and num_blocks is adjusted to %d"
" accordingly", padded_head_size, num_blocks)
head_size = padded_head_size
return (num_blocks, block_size, num_kv_heads * 2, head_size)
@staticmethod
@@ -132,8 +143,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
@@ -187,6 +196,18 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
padded_head_size = cdiv(
self.head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
query = torch.nn.functional.pad(
query, (0, padded_head_size - self.head_size), value=0.0)
key = torch.nn.functional.pad(
key, (0, padded_head_size - self.head_size), value=0.0)
value = torch.nn.functional.pad(
value, (0, padded_head_size - self.head_size), value=0.0)
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
@@ -213,6 +234,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
soft_cap=self.logits_soft_cap,
)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
output = output[:, :, :self.head_size]
return output.reshape(num_tokens, hidden_size)
@@ -231,11 +255,8 @@ def write_to_kv_cache(
"""
_, _, num_combined_kv_heads, head_size = kv_cache.shape
num_kv_heads = num_combined_kv_heads // 2
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size)