Replace head_mapping params with num_kv_heads in the attention kernel. (#1997)

Co-authored-by: wangguoya <wangguoya@baidu.com>
Co-authored-by: Yang Zhao <zhaoyangstar@foxmail.com>
Author: wbn
Date: 2023-12-11 02:12:53 +08:00
Committed by: GitHub
Parent: 24cde76a15
Commit: dacaf5a400

5 changed files with 26 additions and 37 deletions
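The interface change is mechanical: instead of the caller materializing a head_mapping tensor that assigns each query head to its KV head, the kernel now takes num_kv_heads and derives that mapping itself. A minimal sketch of the equivalence, using hypothetical example values (this is illustration, not the actual kernel code):

import torch

num_query_heads = 8  # hypothetical values for illustration
num_kv_heads = 2
num_queries_per_kv = num_query_heads // num_kv_heads

# Old interface: the caller built this tensor and passed it to the kernel.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)

# New interface: given num_kv_heads alone, the same assignment can be
# recovered on the fly: query head i reads KV head i // num_queries_per_kv.
derived = torch.arange(num_query_heads, dtype=torch.int32) // num_queries_per_kv
assert torch.equal(head_mapping, derived)

Passing the scalar spares every call site from allocating and uploading a per-head index tensor.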


@@ -131,9 +131,6 @@ def test_paged_attention(
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -170,7 +167,7 @@ def test_paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         block_tables,
         context_lens,
@@ -202,7 +199,7 @@ def test_paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         block_tables,
         context_lens,