[V1] Remove pre-allocation for KV cache (#16941)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-04-22 00:52:18 -07:00
committed by GitHub
parent 2689d5c027
commit c4ab9f3e71
5 changed files with 61 additions and 141 deletions

View File

@@ -8,7 +8,7 @@ import torch
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv, sha256
from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
@@ -61,7 +61,6 @@ def test_prefill(hash_algo):
max_model_len=8192,
enable_caching=True,
caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
)
# choose the hash function according to the parameter
@@ -80,7 +79,7 @@ def test_prefill(hash_algo):
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
# Check full block metadata
parent_block_hash = None
@@ -92,8 +91,8 @@ def test_prefill(hash_algo):
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (4, 5):
# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@@ -107,12 +106,12 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6, 7]
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3
# At this point, we should have 5 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
manager.free(req0)
manager.free(req1)
@@ -120,14 +119,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (4)]
# [unique_req1 (5)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
@@ -139,29 +138,29 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [8, 9]
assert [b.block_id for b in blocks] == [6]
# Although we only have 5 free blocks, we have 8 blocks in
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
assert manager.block_pool.free_block_queue.num_free_blocks == 6
assert all([
b.ref_cnt == 0
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
])
assert len([
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]) == 5
]) == 6
manager.free(req2)
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -178,7 +177,6 @@ def test_prefill_plp():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
# the default hash function is hash
hash_fn = hash
@@ -197,7 +195,7 @@ def test_prefill_plp():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks]
# Check full block metadata
@@ -210,8 +208,8 @@ def test_prefill_plp():
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (4, 5):
# Check partial block metadata
for block_id in (4, ):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@@ -226,12 +224,12 @@ def test_prefill_plp():
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6, 7]
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 3
# At this point, we should have 5 free blocks left.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
manager.free(req0)
manager.free(req1)
@@ -239,14 +237,14 @@ def test_prefill_plp():
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [unallocated (6, 7, 8, 9, 10)]
# [unique_req0 (4)]
# [unique_req1 (5)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
@@ -262,7 +260,7 @@ def test_prefill_plp():
block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4, 5]
assert block_ids != [1, 2, 3, 4]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@@ -277,7 +275,6 @@ def test_decode():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
# Complete 3 blocks (48 tokens)
@@ -291,7 +288,7 @@ def test_decode():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@@ -299,28 +296,18 @@ def test_decode():
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
# Append slots without allocating a new block, but start using the
# preallocated block.
req0.num_computed_tokens = 59
# 6 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(5 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
# Append slots with allocating a new block.
req0.num_computed_tokens = 74
# 6 tokens to fill the previous block, and 10 tokens to fill
req0.num_computed_tokens = 59
# 9 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(6 + 11):
req0.append_output_token_ids(12)
new_blocks = manager.allocate_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
def test_evict():
@@ -328,7 +315,6 @@ def test_evict():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
last_token_id = 5 * 16 + 7
@@ -337,7 +323,7 @@ def test_evict():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated
assert len(blocks) == 6 # 5 full + 1 partial
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
@@ -349,7 +335,8 @@ def test_evict():
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
assert manager.block_pool.free_block_queue.num_free_blocks == 0
# 10 - (6 + 3) == 1
assert manager.block_pool.free_block_queue.num_free_blocks == 1
manager.free(req0)
manager.free(req1)
@@ -357,7 +344,7 @@ def test_evict():
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
@@ -365,8 +352,8 @@ def test_evict():
assert [b.block_id for b in computed_blocks] == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [7, 6]
assert manager.block_pool.free_block_queue.num_free_blocks == 6
assert [b.block_id for b in blocks] == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
def test_hash_block_correct_reuse():
@@ -379,7 +366,6 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
# Allocate 1 block and cache it.
@@ -416,7 +402,6 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
# Allocate a block and cache it.
@@ -465,7 +450,6 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
num_preallocate_tokens=0,
)
req1 = make_request("1", list(range(10))) # 2 blocks and some more
@@ -496,40 +480,6 @@ def test_basic_prefix_caching_disabled():
assert not blocks
@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
"""
This tests that the preallocated blocks are correctly added.
"""
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
req = make_request("0", list(range(block_size * 30)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert num_computed_tokens == 0
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks
# Assume all computed, only when num_preallocate_tokens > 0, we need to
# consume the previously preallocated blocks.
if num_preallocated_blocks > 0:
manager.allocate_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks)
# Append 1 block.
blocks = manager.allocate_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_cache_blocks(hash_fn):
"""
@@ -588,7 +538,6 @@ def test_mm_prefix_caching():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=16,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -626,7 +575,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@@ -667,7 +616,6 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@@ -721,7 +669,6 @@ def test_reset_prefix_cache():
make_kv_cache_config(16, 11),
max_model_len=8192,
enable_caching=True,
num_preallocate_tokens=0,
)
full_block_token_ids = [i for i in range(3) for _ in range(16)]