Basic attention kernel that supports cached KV + (multi-)prompts (#24)
This commit is contained in:
committed by
GitHub
parent
897cb2ae28
commit
21b3671bbc
@@ -97,6 +97,61 @@ def ref_multi_query_kv_attention(
|
||||
return ref_output
|
||||
|
||||
|
||||
def ref_multi_query_cached_kv_attention(
|
||||
cu_query_lens: List[int],
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
|
||||
num_queries = len(cu_query_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_queries):
|
||||
start_idx = cu_query_lens[i]
|
||||
end_idx = cu_query_lens[i + 1]
|
||||
query_len = end_idx - start_idx
|
||||
context_len = int(context_lens[i])
|
||||
block_table = block_tables[i]
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
keys = []
|
||||
values = []
|
||||
for j in range(context_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_heads, head_size)
|
||||
keys.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values.append(v)
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
keys,
|
||||
values,
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
|
||||
def test_single_query_cached_kv_attention(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
@@ -216,6 +271,76 @@ def test_multi_query_kv_attention(
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_multi_query_cached_kv_attention(
|
||||
num_queries: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries)
|
||||
cu_query_lens = [0]
|
||||
for query_len in query_lens:
|
||||
cu_query_lens.append(cu_query_lens[-1] + query_len)
|
||||
num_total_tokens = cu_query_lens[-1]
|
||||
|
||||
query = torch.randn(
|
||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
||||
value_block_shape = (num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
|
||||
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
|
||||
context_lens = [
|
||||
query_len + random.randint(0, MAX_SEQ_LEN - query_len)
|
||||
for query_len in query_lens
|
||||
]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_queries):
|
||||
block_table = [
|
||||
random.randint(0, num_blocks - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty_like(query)
|
||||
|
||||
attention_ops.multi_query_cached_kv_attention(
|
||||
cu_query_lens,
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
)
|
||||
|
||||
ref_output = ref_multi_query_cached_kv_attention(
|
||||
cu_query_lens,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
dtype,
|
||||
)
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_attention(seed: int) -> None:
|
||||
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
|
||||
@@ -237,6 +362,24 @@ def test_attention(seed: int) -> None:
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# NOTE(siyuan): Same as above. Re-run the test if it fails. Also
|
||||
# note that the test is also more likely to fail due to the much
|
||||
# larger amount of tokens in the input may increase the variance.
|
||||
for dtype in [torch.half, torch.float]:
|
||||
for block_size in [8, 16]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Testing multi_query_cached_kv_attention with '
|
||||
f'dtype={dtype}, block_size={block_size}, '
|
||||
f'head_size={head_size}')
|
||||
test_multi_query_cached_kv_attention(
|
||||
num_queries=11,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
num_blocks=1024,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): FlashAttention does not support FP32.
|
||||
for dtype in [torch.half]:
|
||||
# NOTE(woosuk): FlashAttention does not support head_size > 128.
|
||||
|
||||
Reference in New Issue
Block a user