[Core] Use KVCacheBlock as much as possible instead of dict[block_id, KVCacheBlock] (#24830)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
@@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256, sha256_cbor
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
get_block_hash, get_group_id,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock, get_block_hash,
|
||||
get_group_id,
|
||||
get_request_block_hasher,
|
||||
hash_block_tokens, init_none_hash,
|
||||
make_block_hash_with_group_id)
|
||||
@@ -138,7 +139,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@@ -171,7 +172,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@@ -207,7 +208,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([6], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([6], )
|
||||
|
||||
# Although we only have 6 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
@@ -227,7 +228,9 @@ def test_prefill(hash_fn):
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([
|
||||
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
|
||||
], )
|
||||
|
||||
assert free_block_queue.num_free_blocks == 0
|
||||
assert (free_block_queue.fake_free_list_head.next_free_block
|
||||
@@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
|
||||
8], [9, 10, 11, 12])
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
|
||||
5, 6, 7, 8
|
||||
], [9, 10, 11, 12])
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([13], [14], [15])
|
||||
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
if block != manager.block_pool.null_block:
|
||||
@@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
|
||||
manager.free(req1)
|
||||
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block)
|
||||
manager.block_pool.cached_block_hash_to_block._cache)
|
||||
|
||||
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||
block_size, sha256)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||
hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(req.block_hashes) == 3
|
||||
@@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block[
|
||||
manager.block_pool.cached_block_hash_to_block._cache[
|
||||
hash_with_group_id] = cached_block_hash_to_block_bak[
|
||||
hash_with_group_id]
|
||||
manager.free(req)
|
||||
@@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
|
||||
# total cache miss.
|
||||
# The cache hit length of full attention is 1 * block_size.
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is cache miss as the two type of layers have different hit length.
|
||||
# Then it is cache miss as the two type of layers
|
||||
# have different hit length.
|
||||
test_partial_request_hit("8", [
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
@@ -406,7 +412,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||
|
||||
# Check full block metadata
|
||||
@@ -441,7 +447,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@@ -478,6 +484,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req2, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||
@@ -513,7 +520,7 @@ def test_decode():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
@@ -558,7 +565,8 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
|
||||
# 5 full + 1 partial
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 6
|
||||
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
@@ -570,7 +578,7 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req1, 3 * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
|
||||
# 10 - (6 + 3) == 1
|
||||
@@ -592,7 +600,7 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([10], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([10], )
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
|
||||
|
||||
@@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
|
||||
# Deallocate the block.
|
||||
manager.free(req)
|
||||
@@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
|
||||
assert manager.block_pool.blocks[blocks.blocks[0]
|
||||
[0].block_id].block_hash is None
|
||||
@@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req0, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 1
|
||||
|
||||
# Allocate another block.
|
||||
@@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req1, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
# Free the blocks.
|
||||
@@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
|
||||
@@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
||||
blocks = manager.allocate_slots(req1, 10,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 3
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 3
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req1)
|
||||
@@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
|
||||
blocks = manager.allocate_slots(req2, 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||
@@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Test that blocks that don't start from the beginning are cached correctly.
|
||||
# Test that blocks that don't start from the beginning are cached
|
||||
# correctly.
|
||||
blocks += [KVCacheBlock(block_id=2)]
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
@@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
@@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
@@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
|
||||
# Manually add all blocks to cached_blocks
|
||||
for block, block_hash in zip(pool.blocks, block_hashes):
|
||||
block.block_hash = block_hash
|
||||
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
|
||||
pool.cached_block_hash_to_block.insert(block_hash, block)
|
||||
|
||||
block0, block1, block2, block3 = pool.blocks
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block0.block_id: block0,
|
||||
block3.block_id: block3
|
||||
block3.block_id: block3,
|
||||
},
|
||||
block_hash1: {
|
||||
block1.block_id: block1
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash1: block1,
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block1
|
||||
pool._maybe_evict_cached_block(block1)
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block0.block_id: block0,
|
||||
block3.block_id: block3
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block0: block_hash0 entry should NOT be removed, as block3
|
||||
# also use the same hash
|
||||
pool._maybe_evict_cached_block(block0)
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block3.block_id: block3
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block2
|
||||
pool._maybe_evict_cached_block(block2)
|
||||
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
|
||||
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
|
||||
# Evict block3
|
||||
pool._maybe_evict_cached_block(block3)
|
||||
assert pool.cached_block_hash_to_block == {}
|
||||
assert pool.cached_block_hash_to_block._cache == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
|
||||
@@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
|
||||
# Evict the first block in the request
|
||||
assert manager.block_pool.get_cached_block(
|
||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
@@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
|
||||
# there will be no matched prefix.
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_tokens == 0
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
key1 = BlockHashWithGroupId(b"hash1")
|
||||
key2 = BlockHashWithGroupId(b"hash2")
|
||||
block0 = KVCacheBlock(0)
|
||||
block1 = KVCacheBlock(1)
|
||||
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.get_one_block(key2) is None
|
||||
# key0 inserted
|
||||
cache.insert(key0, block0)
|
||||
assert cache.get_one_block(key0) is block0
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.get_one_block(key2) is None
|
||||
# key1 inserted
|
||||
cache.insert(key1, block1)
|
||||
assert cache.get_one_block(key0) is block0
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# No block poped due to block_id mismatch
|
||||
assert cache.pop(key0, 100) is None
|
||||
assert cache.get_one_block(key0) is block0
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# block poped with (key0, block ID 0)
|
||||
assert cache.pop(key0, 0) is block0
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# No block poped due to block_id mismatch
|
||||
assert cache.pop(key0, 1) is None
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# block poped with (key1, block ID 1)
|
||||
assert cache.pop(key1, 1) is block1
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.get_one_block(key2) is None
|
||||
|
||||
|
||||
def test_block_lookup_cache_multi_blocks_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
key1 = BlockHashWithGroupId(b"hash1")
|
||||
block00 = KVCacheBlock(0)
|
||||
block01 = KVCacheBlock(1)
|
||||
block10 = KVCacheBlock(10)
|
||||
block11 = KVCacheBlock(11)
|
||||
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is None
|
||||
|
||||
cache.insert(key0, block00)
|
||||
cache.insert(key0, block01)
|
||||
cache.insert(key1, block10)
|
||||
cache.insert(key1, block11)
|
||||
|
||||
assert cache.get_one_block(key0) is block00
|
||||
assert cache.pop(key0, 0) is block00
|
||||
assert cache.get_one_block(key0) is block01
|
||||
assert cache.pop(key0, 1) is block01
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.pop(key0, 2) is None
|
||||
|
||||
assert cache.get_one_block(key1) is block10
|
||||
assert cache.pop(key1, 10) is block10
|
||||
assert cache.get_one_block(key1) is block11
|
||||
assert cache.pop(key1, 11) is block11
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.pop(key1, 12) is None
|
||||
|
||||
Reference in New Issue
Block a user