Implement block copy kernel to optimize beam search (#32)

Woosuk Kwon
2023-04-07 17:45:07 -07:00
committed by GitHub
parent a490aafa36
commit 0f40557af6
6 changed files with 154 additions and 48 deletions
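
For context: in beam search, candidate sequences share a common prefix, so their block tables point at the same physical KV-cache blocks. When a beam forks, the block that is about to diverge must be duplicated, and batching all such copies across every layer into one kernel launch is what this commit enables. A minimal sketch of how such a mapping might be assembled; the block-table bookkeeping here (block_tables, free_blocks, fork_beam) is a hypothetical illustration, not part of this commit, and only the copy_blocks call mirrors the API exercised by the test below:

# Hypothetical sketch: fork beam `parent` into beam `child` by copying
# only the last (diverging) block; earlier blocks stay shared.
block_tables = {0: [3, 7, 11]}   # beam id -> physical block numbers (assumed)
free_blocks = [20, 21, 22]       # assumed free list of physical blocks

def fork_beam(parent: int, child: int) -> dict:
    src_block = block_tables[parent][-1]   # only the tail block diverges
    dst_block = free_blocks.pop()
    block_tables[child] = block_tables[parent][:-1] + [dst_block]
    return {src_block: [dst_block]}        # src block -> list of dst blocks

block_mapping = fork_beam(parent=0, child=1)
# One launch copies every mapped block in all layers' key and value caches.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)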


@@ -5,6 +5,61 @@ import torch
from cacheflow import cache_ops
def test_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    # Generate random block mappings with disjoint source and destination blocks.
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, num_mappings)
    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
    # Create the KV cache.
    # x is the number of elements that fit in 16 bytes; the key cache packs
    # the head dimension in groups of x.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.randn(
            size=key_cache_shape, dtype=dtype, device='cuda')
        key_caches.append(key_cache)
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.randn(
            size=value_cache_shape, dtype=dtype, device='cuda')
        value_caches.append(value_cache)
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
    # Call the copy blocks kernel.
    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

    # Reference implementation: apply the same copies to the cloned caches.
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst] = cloned_key_cache[src]
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst] = cloned_value_cache[src]

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)
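
A note on the key-cache shape above: `x = 16 // element_size` is the number of elements that fit in 16 bytes, so the cache stores the head dimension as `head_size // x` groups of `x` contiguous values, which appears to match a 16-byte vectorized access pattern in the kernel. A quick sanity check, using the same values as the `test_cache` call below:

# For torch.half, element_size() == 2, so x == 8; with head_size=16 the
# key cache shape becomes (num_blocks, num_heads, 2, block_size, 8).
x = 16 // torch.tensor([], dtype=torch.half).element_size()
assert x == 8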
def test_reshape_and_cache(
    num_tokens: int,
    num_heads: int,
@@ -46,6 +101,9 @@ def test_reshape_and_cache(
@torch.inference_mode()
def test_cache() -> None:
    test_copy_blocks(
        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
        block_size=8, num_blocks=1024, dtype=torch.half)
    test_reshape_and_cache(
        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
        dtype=torch.half)