Replace head_mapping params with num_kv_heads to attention kernel. (#1997)

Co-authored-by: wangguoya <wangguoya@baidu.com>
Co-authored-by: Yang Zhao <zhaoyangstar@foxmail.com>
This commit is contained in:
wbn
2023-12-11 02:12:53 +08:00
committed by GitHub
parent 24cde76a15
commit dacaf5a400
5 changed files with 26 additions and 37 deletions

View File

@@ -54,9 +54,6 @@ class PagedAttention(nn.Module):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
self.num_queries_per_kv)
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -77,7 +74,7 @@ class PagedAttention(nn.Module):
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -172,7 +169,7 @@ class PagedAttention(nn.Module):
key_cache,
value_cache,
input_metadata,
self.head_mapping,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
@@ -217,7 +214,7 @@ def _paged_attention(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
head_mapping: torch.Tensor,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
@@ -244,7 +241,7 @@ def _paged_attention(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
@@ -274,7 +271,7 @@ def _paged_attention(
query,
key_cache,
value_cache,
head_mapping,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,