[v1] Hybrid Memory Allocator (#17996)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
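This change makes the KV cache manager group-aware so hybrid models (full-attention layers plus sliding-window layers) can share one block pool. The API shifts visible in the test diff below: `computed_blocks.blocks` is now a list of block lists, one per KV cache group; `manager.single_type_manager` becomes `manager.coordinator.single_type_managers[group_id]`; and prefix-cache entries are keyed by `BlockHashWithGroupId` rather than a bare `BlockHash`. A minimal self-contained sketch of the per-group shape follows; `Block` and the `KVCacheBlocks` body here are illustrative stand-ins that mirror only what the tests exercise, not vLLM's implementation:

from dataclasses import dataclass, field

@dataclass
class Block:
    # Stand-in for vLLM's KVCacheBlock: just an id and a reference count.
    block_id: int
    ref_cnt: int = 0

@dataclass
class KVCacheBlocks:
    # One list of blocks per KV cache group. In the hybrid config below,
    # group 0 is full attention and groups 1 and 2 are sliding window.
    blocks: list[list[Block]] = field(default_factory=list)

    def get_block_ids(self) -> list[list[int]]:
        return [[b.block_id for b in group] for group in self.blocks]

# Old single-group access vs. new per-group access:
computed = KVCacheBlocks(blocks=[[Block(1), Block(2)], [Block(5)], [Block(9)]])
assert computed.get_block_ids() == [[1, 2], [5], [9]]
num_computed_tokens = len(computed.blocks[0]) * 16  # was len(computed.blocks) * 16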
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Compare the with and without prefix caching."""
 
+import copy
 from typing import Optional
 
 import pytest
@@ -13,8 +14,8 @@ from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
-from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
-                                         hash_block_tokens)
+from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
+                                         KVCacheBlock, hash_block_tokens)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, SlidingWindowSpec)
 
@@ -47,7 +48,7 @@ def make_request(request_id,
 def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
     return KVCacheConfig(
         num_blocks=num_blocks,
-        tensors={},
+        kv_cache_tensors=[],
         kv_cache_groups=[
             KVCacheGroupSpec(
                 ["layer"],
@@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
     )
 
 
+def make_kv_cache_config_hybrid_model(block_size: int,
+                                      num_blocks: int) -> KVCacheConfig:
+    return KVCacheConfig(
+        num_blocks=num_blocks,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer1"],
+                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+            ),
+            KVCacheGroupSpec(
+                ["layer2"],
+                SlidingWindowSpec(block_size,
+                                  1,
+                                  1,
+                                  torch.float32,
+                                  False,
+                                  sliding_window=2 * block_size),
+            ),
+            KVCacheGroupSpec(
+                ["layer3"],
+                SlidingWindowSpec(block_size,
+                                  1,
+                                  1,
+                                  torch.float32,
+                                  False,
+                                  sliding_window=2 * block_size),
+            ),
+        ],
+    )
+
+
 @pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
 def test_prefill(hash_algo):
     manager = KVCacheManager(
@@ -79,10 +112,10 @@ def test_prefill(hash_algo):
     req0 = make_request("0", all_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(manager.req_to_block_hashes[req0.request_id]) == 3
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[1, 2, 3, 4]]
 
@@ -92,7 +125,8 @@ def test_prefill(hash_algo):
         block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
         block_hash = hash_block_tokens(hash_fn, parent_block_hash,
                                        block_tokens)
-        assert manager.block_pool.blocks[block_id].block_hash == block_hash
+        assert manager.block_pool.blocks[
+            block_id].block_hash.block_hash == block_hash
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
 
@@ -111,10 +145,10 @@ def test_prefill(hash_algo):
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[5]]
-    for block in computed_blocks.blocks:
+    for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -145,7 +179,7 @@ def test_prefill(hash_algo):
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[6]]
 
@@ -165,10 +199,10 @@ def test_prefill(hash_algo):
     # Cache miss and eviction.
     req3 = make_request("3", [99] * (16 * 10))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req3, 16 * 10,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
     assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
@@ -177,6 +211,138 @@ def test_prefill(hash_algo):
     assert manager.block_pool.free_block_queue.free_list_tail is None
 
 
+def test_prefill_hybrid_model():
+    block_size = 16
+    manager = KVCacheManager(
+        make_kv_cache_config_hybrid_model(block_size, 21),
+        max_model_len=8192,
+        enable_caching=True,
+    )
+
+    hash_fn = hash
+
+    # Complete 3 blocks (48 tokens)
+    common_token_ids = [i for i in range(3) for _ in range(block_size)]
+
+    # Full cache miss
+    # Incomplete 1 block (7 tokens)
+    unique_token_ids = [3] * 7
+    all_token_ids = common_token_ids + unique_token_ids
+    req0 = make_request("0", all_token_ids)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
+    assert not computed_blocks.blocks[0]
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(req0, 55,
+                                    len(computed_blocks.blocks[0]) * 16,
+                                    computed_blocks)
+    assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
+                                      [9, 10, 11, 12]]
+
+    # Check full block metadata
+    parent_block_hash = None
+    for length, block_ids in zip((1, 2, 3),
+                                 ((1, 5, 9), (2, 6, 10), (3, 7, 11))):
+        block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
+        block_hash = hash_block_tokens(hash_fn, parent_block_hash,
+                                       block_tokens)
+        for block_id in block_ids:
+            assert manager.block_pool.blocks[
+                block_id].block_hash.block_hash == block_hash
+            assert manager.block_pool.blocks[block_id].ref_cnt == 1
+        parent_block_hash = block_hash.hash_value
+
+    # Check partial block metadata
+    for block_id in (4, 8, 12):
+        assert manager.block_pool.blocks[block_id].block_hash is None
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+
+    # Cache hit in the common prefix
+    # Incomplete 1 block (5 tokens)
+    unique_token_ids = [3] * 5
+    req1 = make_request("1", common_token_ids + unique_token_ids)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
+    assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
+                                               [0, 10, 11]]
+    assert num_computed_tokens == 3 * 16
+    num_new_tokens = 53 - 3 * 16
+    blocks = manager.allocate_slots(req1, num_new_tokens,
+                                    len(computed_blocks.blocks[0]) * 16,
+                                    computed_blocks)
+    assert blocks.get_block_ids() == [[13], [14], [15]]
+    for block_per_group in computed_blocks.blocks:
+        for block in block_per_group:
+            if block != manager.block_pool.null_block:
+                assert block.ref_cnt == 2
+
+    block_hashes = manager.req_to_block_hashes[req1.request_id]
+    manager.free(req0)
+    manager.free(req1)
+
+    cached_block_hash_to_block_bak = copy.copy(
+        manager.block_pool.cached_block_hash_to_block)
+
+    def test_partial_request_hit(request_id: str,
+                                 hash_to_evict: list[BlockHashWithGroupId],
+                                 expect_hit_length: int):
+        req = make_request(request_id, common_token_ids + unique_token_ids)
+        for hash_with_group_id in hash_to_evict:
+            manager.block_pool.cached_block_hash_to_block.pop(
+                hash_with_group_id)
+        computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
+        assert len(manager.req_to_block_hashes[req.request_id]) == 3
+        assert num_computed_tokens == expect_hit_length * block_size
+        for block_per_group in computed_blocks.blocks:
+            assert len(block_per_group) == num_computed_tokens // block_size
+        for hash_with_group_id in hash_to_evict:
+            manager.block_pool.cached_block_hash_to_block[
+                hash_with_group_id] = cached_block_hash_to_block_bak[
+                    hash_with_group_id]
+        manager.free(req)
+
+    # Evicting blocks outside the sliding window does not affect the hit length.
+    test_partial_request_hit("2", [
+        BlockHashWithGroupId(block_hashes[0], 1),
+        BlockHashWithGroupId(block_hashes[0], 2)
+    ], 3)
+
+    # Evicting the first block of full attention makes it a total cache miss.
+    test_partial_request_hit("3", [
+        BlockHashWithGroupId(block_hashes[0], 0),
+    ], 0)
+
+    # Evicting the last block of all layers reduces the hit length to 2.
+    test_partial_request_hit("4", [
+        BlockHashWithGroupId(block_hashes[2], 0),
+        BlockHashWithGroupId(block_hashes[2], 1),
+        BlockHashWithGroupId(block_hashes[2], 2),
+    ], 2)
+
+    # Evicting the last block of full attention reduces the hit length to 2.
+    test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
+                             2)
+
+    # Evicting the last block of sliding window reduces the hit length to 2.
+    test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
+                             2)
+
+    # Evicting the last block of sliding window reduces the hit length to 2.
+    test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
+                             2)
+
+    # Evicting different sets of blocks for full attention and sliding window
+    # makes it a total cache miss.
+    # The cache hit length of full attention is 1 * block_size.
+    # The cache hit length of sliding window is 2 * block_size.
+    # Then it is a cache miss as the two layer types have different hit lengths.
+    test_partial_request_hit("8", [
+        BlockHashWithGroupId(block_hashes[2], 0),
+        BlockHashWithGroupId(block_hashes[0], 1),
+        BlockHashWithGroupId(block_hashes[0], 2),
+    ], 0)
+
+
 def test_prefill_plp():
     '''Test prefill with APC and some prompt logprobs (plp) requests.
 
@@ -203,13 +369,13 @@ def test_prefill_plp():
     req0 = make_request("0", all_token_ids, prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(manager.req_to_block_hashes[req0.request_id]) == 0
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[1, 2, 3, 4]]
-    req0_block_hashes = [b.block_hash for b in blocks.blocks]
+    req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
 
     # Check full block metadata
     parent_block_hash = None
@@ -217,7 +383,8 @@ def test_prefill_plp():
         block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
         block_hash = hash_block_tokens(hash_fn, parent_block_hash,
                                        block_tokens)
-        assert manager.block_pool.blocks[block_id].block_hash == block_hash
+        assert manager.block_pool.blocks[
+            block_id].block_hash.block_hash == block_hash
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
 
@@ -237,10 +404,10 @@ def test_prefill_plp():
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[5]]
-    for block in computed_blocks.blocks:
+    for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -269,14 +436,14 @@ def test_prefill_plp():
                         prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 0
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req2, 55,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
-    assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
+    assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
     assert block_ids != [[1, 2, 3, 4]]
 
     # Request #2 block hashes are valid since request #0 hashes are.
@@ -302,10 +469,10 @@ def test_decode():
     unique_token_ids = [3] * 7
     req0 = make_request("0", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[1, 2, 3, 4]]
 
@@ -314,10 +481,10 @@ def test_decode():
     for _ in range(4):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 4,
-                                        len(computed_blocks.blocks) * 16,
+                                        len(computed_blocks.blocks[0]) * 16,
                                         computed_blocks)
-    assert new_blocks is not None and len(new_blocks.blocks) == 0
-    assert manager.single_type_manager.req_to_blocks[
+    assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
+    assert manager.coordinator.single_type_managers[0].req_to_blocks[
         req0.request_id][-1].block_hash is None
 
     # Append slots with allocating a new block.
@@ -327,12 +494,12 @@ def test_decode():
     for _ in range(9 + 10):
         req0.append_output_token_ids(7)
     new_blocks = manager.allocate_slots(req0, 19,
-                                        len(computed_blocks.blocks) * 16,
+                                        len(computed_blocks.blocks[0]) * 16,
                                         computed_blocks)
-    assert new_blocks is not None and len(new_blocks.blocks) == 1
-    assert manager.single_type_manager.req_to_blocks[
+    assert new_blocks is not None and len(new_blocks.blocks[0]) == 1
+    assert manager.coordinator.single_type_managers[0].req_to_blocks[
         req0.request_id][-2].block_hash is not None
-    assert manager.single_type_manager.req_to_blocks[
+    assert manager.coordinator.single_type_managers[0].req_to_blocks[
         req0.request_id][-1].block_hash is None
 
 
@@ -346,23 +513,23 @@ def test_evict():
     last_token_id = 5 * 16 + 7
     req0 = make_request("0", list(range(last_token_id)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 5 * 16 + 7,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 6  # 5 full + 1 partial
+    assert len(blocks.blocks[0]) == 6  # 5 full + 1 partial
 
     # 3 blocks.
     req1 = make_request("1", list(range(last_token_id,
                                         last_token_id + 3 * 16)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, 3 * 16,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 3  # 3 full blocks
+    assert len(blocks.blocks[0]) == 3  # 3 full blocks
     last_token_id += 3 * 16
 
     # 10 - (6 + 3) == 1
@@ -382,7 +549,7 @@ def test_evict():
     assert computed_blocks.get_block_ids() == [[1, 2]]
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[10]]
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
@@ -404,12 +571,12 @@ def test_hash_block_correct_reuse():
     num_tokens = block_size * 1
     req = make_request("0", list(range(num_tokens)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req, num_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 1
+    assert len(blocks.blocks[0]) == 1
 
     # Deallocate the block.
     manager.free(req)
@@ -418,15 +585,15 @@ def test_hash_block_correct_reuse():
     # block is cleared.
     req = make_request("1", list(range(num_tokens - 1)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req, num_tokens - 1,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 1
+    assert len(blocks.blocks[0]) == 1
 
-    assert manager.block_pool.blocks[
-        blocks.blocks[0].block_id].block_hash is None
+    assert manager.block_pool.blocks[blocks.blocks[0]
+                                     [0].block_id].block_hash is None
 
 
 def test_computed_blocks_not_evicted():
@@ -445,24 +612,24 @@ def test_computed_blocks_not_evicted():
     num_tokens = block_size * 1
     req0 = make_request("0", list(range(num_tokens)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, num_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 1
-    assert blocks.blocks[0].block_id == 1
+    assert len(blocks.blocks[0]) == 1
+    assert blocks.blocks[0][0].block_id == 1
 
     # Allocate another block.
     req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, num_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 1
-    assert blocks.blocks[0].block_id == 2
+    assert len(blocks.blocks[0]) == 1
+    assert blocks.blocks[0][0].block_id == 2
 
     # Free the blocks.
     manager.free(req0)
@@ -472,15 +639,15 @@ def test_computed_blocks_not_evicted():
     # cached block rather than the first one.
     req2 = make_request("2", list(range(num_tokens * 2)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(computed_blocks.blocks) == 1
-    assert computed_blocks.blocks[0].block_id == 1
+    assert len(computed_blocks.blocks[0]) == 1
+    assert computed_blocks.blocks[0][0].block_id == 1
     assert num_computed_tokens == block_size
 
     blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 1
-    assert blocks.blocks[0].block_id == 2
+    assert len(blocks.blocks[0]) == 1
+    assert blocks.blocks[0][0].block_id == 2
 
 
 def test_basic_prefix_caching_disabled():
@@ -497,12 +664,12 @@ def test_basic_prefix_caching_disabled():
     req1 = make_request("1", list(range(10)))  # 2 blocks and some more
 
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, 10,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 3
+    assert len(blocks.blocks[0]) == 3
 
     # Free the blocks.
     manager.free(req1)
@@ -510,20 +677,20 @@ def test_basic_prefix_caching_disabled():
     # No caching.
     req2 = make_request("2", list(range(16)))  # shared prefix
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req2, 16,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks) == 4
+    assert len(blocks.blocks[0]) == 4
 
     # New requests should not have any blocks.
     req3 = make_request("3", list(range(4)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req3, 4,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert not blocks
 
@@ -558,6 +725,7 @@ def test_cache_blocks(hash_fn):
         num_full_blocks=2,
         block_size=block_size,
         hash_fn=hash_fn,
+        kv_cache_group_id=0,
     )
 
     assert len(block_pool.cached_block_hash_to_block) == 2
@@ -573,11 +741,83 @@ def test_cache_blocks(hash_fn):
         num_full_blocks=3,
         block_size=block_size,
         hash_fn=hash_fn,
+        kv_cache_group_id=0,
     )
     assert len(block_pool.cached_block_hash_to_block) == 3
     assert blocks[0].block_hash is not None
 
 
+def test_cache_blocks_multi_group():
+    """
+    This tests that blocks are cached correctly for different kv cache groups.
+    """
+    block_size = 4
+    block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
+
+    # Req:
+    #  Block 0/4: [0, 1, 2, 3]
+    #  Block 1/5: [4, 5, 6, 7]
+    #  Block 2/6: [8, 9, 10, 11]
+    #  Block 3/7: [12, 13]
+    req = make_request("0", list(range(14)))
+
+    # Cache the blocks for group 0.
+    blocks = [KVCacheBlock(block_id=i) for i in range(2)]
+    block_hashes: list[BlockHash] = []
+    block_pool.cache_full_blocks(
+        request=req,
+        blocks=blocks,
+        block_hashes=block_hashes,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
+        hash_fn=hash,
+        kv_cache_group_id=0,
+    )
+    assert len(block_pool.cached_block_hash_to_block) == 2
+    assert len(block_hashes) == 2
+    assert all([block.block_hash is not None for block in blocks])
+
+    # Cache the blocks for group 1.
+    blocks = [KVCacheBlock(block_id=i) for i in range(3)]
+    block_pool.cache_full_blocks(
+        request=req,
+        blocks=blocks,
+        block_hashes=block_hashes,
+        num_cached_blocks=0,
+        num_full_blocks=3,
+        block_size=block_size,
+        hash_fn=hash,
+        kv_cache_group_id=1,
+    )
+    assert len(block_pool.cached_block_hash_to_block) == 5
+    assert len(block_hashes) == 3
+    assert all([block.block_hash is not None for block in blocks])
+
+    # Block hash 0: hit for group 0 and 1
+    # Block hash 1: hit for group 0 and 1
+    # Block hash 2: hit for group 1
+
+    assert block_pool.get_cached_block(block_hashes[0],
+                                       kv_cache_group_ids=[0]) is not None
+    assert block_pool.get_cached_block(block_hashes[1],
+                                       kv_cache_group_ids=[0]) is not None
+    assert block_pool.get_cached_block(block_hashes[2],
+                                       kv_cache_group_ids=[0]) is None
+    assert block_pool.get_cached_block(block_hashes[0],
+                                       kv_cache_group_ids=[1]) is not None
+    assert block_pool.get_cached_block(block_hashes[1],
+                                       kv_cache_group_ids=[1]) is not None
+    assert block_pool.get_cached_block(block_hashes[2],
+                                       kv_cache_group_ids=[1]) is not None
+    assert block_pool.get_cached_block(block_hashes[0],
+                                       kv_cache_group_ids=[0, 1]) is not None
+    assert block_pool.get_cached_block(block_hashes[1],
+                                       kv_cache_group_ids=[0, 1]) is not None
+    assert block_pool.get_cached_block(block_hashes[2],
+                                       kv_cache_group_ids=[0, 1]) is None
+
+
 def test_mm_prefix_caching():
     """
     This tests that the multi-modal prefix caching is correct.
@@ -614,7 +854,7 @@ def test_mm_prefix_caching():
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
 
     # Completed block should have hashes with extra keys.
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req0.request_id]
     assert len(block_hashes) == 3
@@ -623,7 +863,7 @@ def test_mm_prefix_caching():
     assert block_hashes[2].extra_keys == ("bbb", )
 
     blocks = manager.allocate_slots(req0, 59,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[1, 2, 3, 4]]
     req0.num_computed_tokens = 59
@@ -632,9 +872,9 @@ def test_mm_prefix_caching():
     for _ in range(5):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 5,
-                                        len(computed_blocks.blocks) * 16,
+                                        len(computed_blocks.blocks[0]) * 16,
                                         computed_blocks)
-    assert new_blocks is not None and len(new_blocks.blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
 
     # The just completed block should have hashes with extra keys.
     assert len(block_hashes) == 4
@@ -652,7 +892,7 @@ def test_mm_prefix_caching():
                         mm_positions=mm_positions,
                         mm_hashes=mm_hashes)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert len(computed_blocks.blocks) == 3
+    assert len(computed_blocks.blocks[0]) == 3
     assert num_computed_tokens == 3 * 16
 
 
@@ -675,7 +915,7 @@ def test_cache_key_salting():
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
 
     # Completed block should have hashes with extra keys.
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req0.request_id]
     assert len(block_hashes) == 3
@@ -684,7 +924,7 @@ def test_cache_key_salting():
     assert block_hashes[2].extra_keys is None
 
     blocks = manager.allocate_slots(req0, 59,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[1, 2, 3, 4]]
     req0.num_computed_tokens = 59
@@ -693,9 +933,9 @@ def test_cache_key_salting():
     for _ in range(5):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 5,
-                                        len(computed_blocks.blocks) * 16,
+                                        len(computed_blocks.blocks[0]) * 16,
                                         computed_blocks)
-    assert new_blocks is not None and len(new_blocks.blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
 
     # Now one more block that should not have extra keys.
     assert len(block_hashes) == 4
@@ -706,14 +946,14 @@ def test_cache_key_salting():
     req1 = make_request("1", token_ids, cache_salt="salt1")
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     # Should match only a prefix of 3 blocks.
-    assert len(computed_blocks.blocks) == 3
+    assert len(computed_blocks.blocks[0]) == 3
     assert num_computed_tokens == 3 * block_size
 
     # Test cache miss with same content but different salt.
     token_ids = common_token_ids + [4] * 11
     req2 = make_request("2", token_ids, cache_salt="salt2")
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(computed_blocks.blocks) == 0
+    assert len(computed_blocks.blocks[0]) == 0
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req2.request_id]
     assert len(block_hashes) == 3
@@ -738,20 +978,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     common_token_ids = [i for i in range(3) for _ in range(16)]
     req0 = make_request("0", common_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     manager.allocate_slots(req0, 48,
-                           len(computed_blocks.blocks) * 16, computed_blocks)
-    block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
+    block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[
+        req0.request_id]
 
     # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
     req1 = make_request("1", common_token_ids * 2)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert computed_blocks.blocks == block_part0
+    assert computed_blocks.blocks[0] == block_part0
     assert num_computed_tokens == 3 * 16
     manager.allocate_slots(req1, 48,
-                           len(computed_blocks.blocks) * 16, computed_blocks)
-    block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
+    block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[
+        req1.request_id]
     # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
     # | Req1-5(F)| ... |
     manager.free(req1)
@@ -762,10 +1006,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # | Req1-5(F)| Req2-0 | Req2-1 | ... |
     req2 = make_request("2", [7] * block_size * 2)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     manager.allocate_slots(req2, block_size * 2,
-                           len(computed_blocks.blocks) * 16, computed_blocks)
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
 
     # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
     # but it cannot be allocated due to insufficient free blocks (2).
@@ -773,11 +1018,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert manager.block_pool.free_block_queue.num_free_blocks == 5
     req3 = make_request("3", common_token_ids * 3)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert computed_blocks.blocks == block_part1
+    assert computed_blocks.blocks[0] == block_part1
     assert num_computed_tokens == 6 * 16
     # Req3 cannot be allocated.
     assert manager.allocate_slots(req3, 48,
-                                  len(computed_blocks.blocks) * 16,
+                                  len(computed_blocks.blocks[0]) * 16,
                                   computed_blocks) is None
     # Block 0-2 are used by Req 1.
     assert {block.ref_cnt for block in block_part1[:3]} == {1}
@@ -804,9 +1049,9 @@ def test_reset_prefix_cache():
     req1 = make_request("1", all_token_ids)
     computed_blocks, _ = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert len(computed_blocks.blocks) == 3
+    assert len(computed_blocks.blocks[0]) == 3
     blocks = manager.allocate_slots(req1, 7,
-                                    len(computed_blocks.blocks) * 16,
+                                    len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     assert blocks.get_block_ids() == [[5]]
 
@@ -836,10 +1081,11 @@ def test_prefix_cache_stats_disabled():
     # Call all functions that check whether log_stats is disabled.
     req = make_request("0", list(range(16)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks.blocks
+    assert not computed_blocks.blocks[0]
     assert num_computed_tokens == 0
     manager.allocate_slots(req, 16,
-                           len(computed_blocks.blocks) * 16, computed_blocks)
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
     manager.reset_prefix_cache()
 
     # Ensure prefix_cache_stats remains None
@@ -918,7 +1164,8 @@ def test_eagle_enabled_removes_last_block():
     # Prime the cache
     computed_blocks, _ = manager.get_computed_blocks(req)
     manager.allocate_slots(req, len(token_ids),
-                           len(computed_blocks.blocks) * 16, computed_blocks)
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
     manager.free(req)
 
     # New request with same tokens + Eagle enabled
@@ -928,7 +1175,7 @@ def test_eagle_enabled_removes_last_block():
     # Should retain 1 block:
     # 1. Original 3 blocks → pop last hash → 2 matched blocks
     # 2. drop last matched block → 1 remaining block
-    assert len(computed_blocks.blocks) == 1
+    assert len(computed_blocks.blocks[0]) == 1
     assert num_tokens == 1 * block_size  # 16 tokens
 
 
@@ -948,14 +1195,15 @@ def test_eagle_with_partial_blocks():
     # Prime the cache
     computed_blocks, _ = manager.get_computed_blocks(req)
     manager.allocate_slots(req, len(token_ids),
-                           len(computed_blocks.blocks) * 16, computed_blocks)
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
     manager.free(req)
 
     # New request with Eagle enabled
     req_eagle = make_request("partial_eagle", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
-    assert len(computed_blocks.blocks) == 1
+    assert len(computed_blocks.blocks[0]) == 1
     assert num_tokens == 1 * block_size
 
 
@@ -973,7 +1221,7 @@ def test_eagle_with_sliding_window():
     manager = KVCacheManager(
         KVCacheConfig(
             num_blocks=10,
-            tensors={},
+            kv_cache_tensors=[],
             kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
         ),
         max_model_len=8192,
@@ -988,7 +1236,8 @@ def test_eagle_with_sliding_window():
     # Prime the cache
     computed_blocks, _ = manager.get_computed_blocks(req)
     manager.allocate_slots(req, len(token_ids),
-                           len(computed_blocks.blocks) * 16, computed_blocks)
+                           len(computed_blocks.blocks[0]) * 16,
+                           computed_blocks)
     # record the block hash of the first block in the request for later use
     block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
     assert block_hash_first_block is not None
@@ -998,13 +1247,14 @@ def test_eagle_with_sliding_window():
     req_eagle = make_request("partial_eagle", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
-    assert len(computed_blocks.blocks) == 1
+    assert len(computed_blocks.blocks[0]) == 1
     assert num_tokens == 1 * block_size
 
     # Evict the first block in the request
     assert manager.block_pool.get_cached_block(
-        block_hash_first_block) is not None
-    manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
+        block_hash_first_block, kv_cache_group_ids=[0]) is not None
+    manager.block_pool.cached_block_hash_to_block.pop(
+        BlockHashWithGroupId(block_hash_first_block, 0))
 
     # New request
     req_after_evict = make_request("partial_eagle_after_evict", token_ids)
@@ -1012,5 +1262,5 @@ def test_eagle_with_sliding_window():
     # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
     # not considered. But after dropping the last matched block due to eagle,
     # there will be no matched prefix.
-    assert len(computed_blocks.blocks) == 0
+    assert len(computed_blocks.blocks[0]) == 0
     assert num_tokens == 0
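
As the eviction cases above show, `cached_block_hash_to_block` is now keyed per group, so the same token prefix can be cached and evicted independently for each layer type. A hedged sketch of that keying, with a plain dict standing in for vLLM's `BlockPool` map and an integer standing in for a real `hash_block_tokens()` result:

from typing import NamedTuple

class BlockHashWithGroupId(NamedTuple):
    # Mirrors the name used in the tests; the real vLLM type carries the
    # full BlockHash rather than this stand-in int.
    block_hash: int
    group_id: int

cached: dict[BlockHashWithGroupId, str] = {}
h = 0xABCD  # stand-in for a hash_block_tokens() result
for group_id in (0, 1, 2):
    cached[BlockHashWithGroupId(h, group_id)] = f"block for group {group_id}"

# Evicting the sliding-window copies (groups 1 and 2) leaves the
# full-attention copy (group 0) cached, as in test_partial_request_hit("2").
cached.pop(BlockHashWithGroupId(h, 1))
cached.pop(BlockHashWithGroupId(h, 2))
assert BlockHashWithGroupId(h, 0) in cached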