[Core] Use KVCacheBlock as much as possible instead of dict[block_id, KVCacheBlock] (#24830)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Jialin Ouyang
2025-09-23 15:11:14 -07:00
committed by GitHub
parent ae002924e9
commit 4f8c4b890a
3 changed files with 247 additions and 87 deletions

View File

@@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
get_block_hash, get_group_id,
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, get_block_hash,
get_group_id,
get_request_block_hasher,
hash_block_tokens, init_none_hash,
make_block_hash_with_group_id)
@@ -138,7 +139,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
# Check full block metadata
parent_block_hash = None
@@ -171,7 +172,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@@ -207,7 +208,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([6], )
assert blocks is not None and blocks.get_block_ids() == ([6], )
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@@ -227,7 +228,9 @@ def test_prefill(hash_fn):
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
assert blocks is not None and blocks.get_block_ids() == ([
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
], )
assert free_block_queue.num_free_blocks == 0
assert (free_block_queue.fake_free_list_head.next_free_block
@@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
8], [9, 10, 11, 12])
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
5, 6, 7, 8
], [9, 10, 11, 12])
# Check full block metadata
parent_block_hash = None
@@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([13], [14], [15])
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
@@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
manager.free(req1)
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block)
manager.block_pool.cached_block_hash_to_block._cache)
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, sha256)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
manager.block_pool.cached_block_hash_to_block._cache.pop(
hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(req.block_hashes) == 3
@@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block[
manager.block_pool.cached_block_hash_to_block._cache[
hash_with_group_id] = cached_block_hash_to_block_bak[
hash_with_group_id]
manager.free(req)
@@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
# total cache miss.
# The cache hit length of full attention is 1 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
# Then it is cache miss as the two type of layers
# have different hit length.
test_partial_request_hit("8", [
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
@@ -406,7 +412,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
# Check full block metadata
@@ -441,7 +447,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@@ -478,6 +484,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
@@ -513,7 +520,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@@ -558,7 +565,8 @@ def test_evict():
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
# 5 full + 1 partial
assert blocks is not None and len(blocks.blocks[0]) == 6
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
@@ -570,7 +578,7 @@ def test_evict():
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 3 # 3 full blocks
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
last_token_id += 3 * 16
# 10 - (6 + 3) == 1
@@ -592,7 +600,7 @@ def test_evict():
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([10], )
assert blocks is not None and blocks.get_block_ids() == ([10], )
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
# Deallocate the block.
manager.free(req)
@@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert manager.block_pool.blocks[blocks.blocks[0]
[0].block_id].block_hash is None
@@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 1
# Allocate another block.
@@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
# Free the blocks.
@@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
@@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 3
assert blocks is not None and len(blocks.blocks[0]) == 3
# Free the blocks.
manager.free(req1)
@@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 4
assert blocks is not None and len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)), block_size, sha256)
@@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
assert len(block_pool.cached_block_hash_to_block) == 2
assert all([block.block_hash is not None for block in blocks])
# Test that blocks that don't start from the beginning are cached correctly.
# Test that blocks that don't start from the beginning are cached
# correctly.
blocks += [KVCacheBlock(block_id=2)]
block_pool.cache_full_blocks(
request=req,
@@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, sha256)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
@@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
# Manually add all blocks to cached_blocks
for block, block_hash in zip(pool.blocks, block_hashes):
block.block_hash = block_hash
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
pool.cached_block_hash_to_block.insert(block_hash, block)
block0, block1, block2, block3 = pool.blocks
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block0.block_id: block0,
block3.block_id: block3
block3.block_id: block3,
},
block_hash1: {
block1.block_id: block1
},
block_hash2: {
block2.block_id: block2
}
block_hash1: block1,
block_hash2: block2,
}
# Evict block1
pool._maybe_evict_cached_block(block1)
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block0.block_id: block0,
block3.block_id: block3
},
block_hash2: {
block2.block_id: block2
}
block_hash2: block2,
}
# Evict block0: block_hash0 entry should NOT be removed, as block3
# also use the same hash
pool._maybe_evict_cached_block(block0)
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block3.block_id: block3
},
block_hash2: {
block2.block_id: block2
}
block_hash2: block2,
}
# Evict block2
pool._maybe_evict_cached_block(block2)
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
# Evict block3
pool._maybe_evict_cached_block(block3)
assert pool.cached_block_hash_to_block == {}
assert pool.cached_block_hash_to_block._cache == {}
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
@@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop(
manager.block_pool.cached_block_hash_to_block._cache.pop(
make_block_hash_with_group_id(block_hash_first_block, 0))
# New request
@@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
# there will be no matched prefix.
assert len(computed_blocks.blocks[0]) == 0
assert num_tokens == 0
def test_block_lookup_cache_single_block_per_key():
    """Exercise BlockHashToBlockMap when each key maps to at most one block.

    Covers: lookups on an empty map, insertion, pop with a mismatched
    block id (must be a no-op), and pop with the exact (key, block id).
    """
    cache = BlockHashToBlockMap()
    keys = [
        BlockHashWithGroupId(b"hash0"),
        BlockHashWithGroupId(b"hash1"),
        BlockHashWithGroupId(b"hash2"),
    ]
    block0 = KVCacheBlock(0)
    block1 = KVCacheBlock(1)

    def check_state(expected):
        # Verify the lookup result for every key in one pass.
        for key, blk in zip(keys, expected):
            assert cache.get_one_block(key) is blk

    # Empty map: nothing resolvable.
    check_state([None, None, None])

    # key0 inserted
    cache.insert(keys[0], block0)
    check_state([block0, None, None])

    # key1 inserted
    cache.insert(keys[1], block1)
    check_state([block0, block1, None])

    # No block popped due to block_id mismatch.
    assert cache.pop(keys[0], 100) is None
    check_state([block0, block1, None])

    # Block popped with (key0, block ID 0).
    assert cache.pop(keys[0], 0) is block0
    check_state([None, block1, None])

    # No block popped due to block_id mismatch.
    assert cache.pop(keys[0], 1) is None
    check_state([None, block1, None])

    # Block popped with (key1, block ID 1).
    assert cache.pop(keys[1], 1) is block1
    check_state([None, None, None])
def test_block_lookup_cache_multi_blocks_per_key():
    """Exercise BlockHashToBlockMap when several blocks share one key.

    get_one_block must return blocks in insertion order, pop must remove
    only the entry whose block id matches exactly, and popping an unknown
    block id must return None.
    """
    cache = BlockHashToBlockMap()
    key0 = BlockHashWithGroupId(b"hash0")
    key1 = BlockHashWithGroupId(b"hash1")
    # Two blocks per key, inserted in ascending block-id order.
    per_key_blocks = {
        key0: [KVCacheBlock(0), KVCacheBlock(1)],
        key1: [KVCacheBlock(10), KVCacheBlock(11)],
    }

    # Nothing resolvable before any insertion.
    assert cache.get_one_block(key0) is None
    assert cache.get_one_block(key1) is None

    for key, blocks in per_key_blocks.items():
        for block in blocks:
            cache.insert(key, block)

    # Drain key0 (missing id 2), then key1 (missing id 12), verifying
    # insertion-order retrieval and exact-id popping for each key.
    for key, missing_id in ((key0, 2), (key1, 12)):
        first, second = per_key_blocks[key]
        assert cache.get_one_block(key) is first
        assert cache.pop(key, first.block_id) is first
        assert cache.get_one_block(key) is second
        assert cache.pop(key, second.block_id) is second
        assert cache.get_one_block(key) is None
        assert cache.pop(key, missing_id) is None