[TPU] support attention head dim smaller than 128 (#19620)
Signed-off-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -17,6 +17,9 @@ from vllm.utils import cdiv, next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TPU requires the head size to be a multiple of 128.
|
||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@@ -43,6 +46,14 @@ class PallasAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
padded_head_size = cdiv(
|
||||
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
num_blocks = num_blocks * head_size // padded_head_size
|
||||
if padded_head_size != head_size:
|
||||
logger.warning_once(
|
||||
"head size is padded to %d, and num_blocks is adjusted to %d"
|
||||
" accordingly", padded_head_size, num_blocks)
|
||||
head_size = padded_head_size
|
||||
return (num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||
|
||||
@staticmethod
|
||||
@@ -132,8 +143,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if head_size % 128 != 0:
|
||||
raise NotImplementedError("Head size must be a multiple of 128.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
@@ -187,6 +196,18 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
padded_head_size = cdiv(
|
||||
self.head_size,
|
||||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, padded_head_size - self.head_size), value=0.0)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, padded_head_size - self.head_size), value=0.0)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, padded_head_size - self.head_size), value=0.0)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
|
||||
# Write input keys and values to the KV cache.
|
||||
@@ -213,6 +234,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
output = output[:, :, :self.head_size]
|
||||
|
||||
return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
|
||||
@@ -231,11 +255,8 @@ def write_to_kv_cache(
|
||||
|
||||
"""
|
||||
_, _, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
num_kv_heads = num_combined_kv_heads // 2
|
||||
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
head_size = cdiv(head_size,
|
||||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
|
||||
head_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user