Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27)
* Add query stride to multi_query_cached_kv_attention
* Add kernel benchmark script
This commit is contained in:
@@ -285,8 +285,9 @@ def test_multi_query_cached_kv_attention(
|
||||
cu_query_lens.append(cu_query_lens[-1] + query_len)
|
||||
num_total_tokens = cu_query_lens[-1]
|
||||
|
||||
query = torch.randn(
|
||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(
|
||||
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
@@ -314,7 +315,8 @@ def test_multi_query_cached_kv_attention(
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty_like(query)
|
||||
output = torch.empty(
|
||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
|
||||
attention_ops.multi_query_cached_kv_attention(
|
||||
cu_query_lens,
|
||||
|
||||
Reference in New Issue
Block a user