Replace the head_mapping parameter with num_kv_heads in the attention kernel. (#1997)
Co-authored-by: wangguoya <wangguoya@baidu.com>
Co-authored-by: Yang Zhao <zhaoyangstar@foxmail.com>
This commit is contained in:
@@ -54,9 +54,6 @@ class PagedAttention(nn.Module):
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
- self.head_mapping = torch.repeat_interleave(
|
||||
- torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
- self.num_queries_per_kv)
|
||||
|
||||
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
@@ -77,7 +74,7 @@ class PagedAttention(nn.Module):
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
- value: shape = [batch_size, num_kv_heads * head_size]
|
||||
+ value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
@@ -172,7 +169,7 @@ class PagedAttention(nn.Module):
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
- self.head_mapping,
|
||||
+ self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
@@ -217,7 +214,7 @@ def _paged_attention(
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
- head_mapping: torch.Tensor,
|
||||
+ num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
@@ -244,7 +241,7 @@ def _paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
- head_mapping,
|
||||
+ num_kv_heads,
|
||||
scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
@@ -274,7 +271,7 @@ def _paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
- head_mapping,
|
||||
+ num_kv_heads,
|
||||
scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
|
||||
Reference in New Issue
Block a user