Optimize data movement (#20)

Author: Woosuk Kwon
Date: 2023-04-02 00:30:17 -07:00
Committed by: GitHub
parent 1f01a18d39
commit 897cb2ae28
17 changed files with 275 additions and 135 deletions


@@ -1,7 +1,7 @@
 import random
 from typing import List, Optional
-from flash_attn.flash_attention import FlashAttention
+from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 from cacheflow import attention_ops
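
Note (mine, not part of this commit): the hunks below replace separately allocated query/key/value tensors with a single packed qkv tensor that is then sliced with unbind. A minimal sketch of why that slicing moves no data: unbind(dim=1) returns views sharing the qkv storage, not copies.

import torch

# Sketch: slicing a packed qkv tensor with unbind(dim=1) yields views, not copies.
num_tokens, num_heads, head_size = 4, 2, 8
qkv = torch.randn(num_tokens, 3, num_heads, head_size)
query, key, value = qkv.unbind(dim=1)

# All three share qkv's storage; only strides and offsets differ.
assert query.data_ptr() == qkv.data_ptr()
assert key.data_ptr() == qkv.data_ptr() + qkv.stride(1) * qkv.element_size()
assert value.data_ptr() == qkv.data_ptr() + 2 * qkv.stride(1) * qkv.element_size()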
@@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    query, _, _ = qkv.unbind(dim=1)
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
     key_cache = torch.randn(
@@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
     value_cache = torch.randn(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
+    # Adjust the range of the values to reduce precision errors.
+    query = query / (head_size ** 0.5)
+    key_cache = key_cache / (head_size ** 0.5)
+    value_cache = value_cache / (head_size ** 0.5)
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
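
A rough intuition for the scaling added above (my reading of the comment, not stated in the commit): for standard-normal inputs, a dot product over head_size dimensions grows like sqrt(head_size), so dividing the inputs by sqrt(head_size) keeps the attention logits and outputs small, which shrinks the absolute fp16 rounding error when the kernel output is compared against the reference.

import torch

# Sketch: magnitude of a q.k dot product before and after the 1/sqrt(head_size) rescaling.
torch.manual_seed(0)
head_size = 256
q = torch.randn(head_size)
k = torch.randn(head_size)
print(q.dot(k).abs())                                           # typically O(sqrt(head_size)), ~16
print((q / head_size ** 0.5).dot(k / head_size ** 0.5).abs())   # typically O(1/sqrt(head_size)), ~0.06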
@@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
     scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty_like(query)
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
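
A plausible reason for allocating output explicitly rather than with torch.empty_like(query) (my guess, not stated in the commit): query is now a strided view into qkv rather than a standalone contiguous tensor, and the custom kernel presumably expects a densely packed [num_tokens, num_heads, head_size] output buffer; an explicit torch.empty makes that layout unambiguous.

import torch

# Sketch: after the unbind above, query is a non-contiguous view of qkv.
qkv = torch.randn(4, 3, 2, 8)
query, _, _ = qkv.unbind(dim=1)
print(query.is_contiguous())        # False: strided view into the packed qkv storage
out = torch.empty(4, 2, 8)
print(out.is_contiguous())          # True: explicitly allocated, densely packed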
@@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
     cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
     scale = float(1.0 / (head_size ** 0.5))
-    query = torch.randn(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    key = torch.rand_like(query)
-    value = torch.rand_like(query)
+    qkv = torch.randn(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    # Adjust the range of the values to reduce precision errors.
+    qkv = qkv / (head_size ** 0.5)
-    qkv = torch.stack([query, key, value], dim=1)
-    flash_attn = FlashAttention(softmax_scale=scale)
-    output = flash_attn(
-        qkv,
-        cu_seqlens=cu_seq_lens,
-        max_s=max_seq_len,
+    query, key, value = qkv.unbind(dim=1)
+    output = torch.empty(
+        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    _flash_attn_forward(
+        query,
+        key,
+        value,
+        output,
+        cu_seq_lens,
+        cu_seq_lens,
+        max_seq_len,
+        max_seq_len,
+        dropout_p=0.0,
+        softmax_scale=scale,
         causal=True,
-    )[0]
+        return_softmax=False,
+    )
     cu_seq_lens = cu_seq_lens.cpu().tolist()
     ref_output = ref_multi_query_kv_attention(
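
For reference, cu_seq_lens in the call above presumably follows the standard varlen flash-attention convention: an exclusive prefix sum of the per-sequence token counts, so sequence i occupies rows cu_seq_lens[i]:cu_seq_lens[i+1] of the packed [num_tokens, ...] tensors. A minimal sketch with made-up lengths:

import torch

# Hypothetical per-sequence lengths (not from the test).
seq_lens = [3, 5, 2]
cu_seq_lens = [0]
for n in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + n)
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int)   # tensor([0, 3, 8, 10], dtype=torch.int32)
max_seq_len = max(seq_lens)   # 5; passed twice since queries and keys share lengths here
num_tokens = sum(seq_lens)    # 10 rows in the packed qkv tensor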