Implement PagedAttention V2 (#1348)

This commit is contained in:
Woosuk Kwon
2023-10-16 00:59:57 -07:00
committed by GitHub
parent 29678cd213
commit 928de46888
6 changed files with 764 additions and 139 deletions

View File

@@ -14,13 +14,14 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 128 # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]
@@ -96,6 +97,7 @@ def ref_single_query_cached_kv_attention(
output[i].copy_(out, non_blocking=True)
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -103,9 +105,9 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_single_query_cached_kv_attention(
def test_paged_attention(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
@@ -162,19 +164,54 @@ def test_single_query_cached_kv_attention(
# Call the paged attention kernel.
output = torch.empty_like(query)
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
if version == "v1":
attention_ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
else:
assert False, f"Unknown version: {version}"
# Run the reference implementation.
ref_output = torch.empty_like(query)