[Hybrid Allocator] Support KV cache groups with different block_size (#29143)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -134,6 +134,7 @@ def test_prefill(hash_fn):
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
@@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
|
||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
hash_fn = sha256
|
||||
@@ -416,6 +418,7 @@ def test_prefill_plp():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# the default hash function is sha256
|
||||
hash_fn = sha256
|
||||
@@ -523,6 +526,7 @@ def test_decode():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
@@ -585,6 +589,7 @@ def test_evict():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
last_token_id = 5 * 16 + 7
|
||||
@@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
|
||||
make_kv_cache_config(16, 2),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Allocate 1 block and cache it.
|
||||
@@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
|
||||
make_kv_cache_config(block_size, 3),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Allocate a block and cache it.
|
||||
@@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
|
||||
make_kv_cache_config(block_size, 5),
|
||||
max_model_len=8192,
|
||||
enable_caching=False,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
req1 = make_request(
|
||||
@@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=5,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# Req:
|
||||
# Block 0: [0, 1, 2, 3]
|
||||
@@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
|
||||
This tests that blocks are cached correctly for different kv cache groups.
|
||||
"""
|
||||
block_size = 4
|
||||
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
|
||||
# Req:
|
||||
# Block 0/4: [0, 1, 2, 3]
|
||||
@@ -921,6 +932,7 @@ def test_mm_prefix_caching():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
|
||||
@@ -1020,6 +1032,7 @@ def test_cache_key_salting():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||
@@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# Complete 3 blocks (48 tokens)
|
||||
# | Common-0 | Common-1 | Common-2 | ... |
|
||||
@@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
log_stats=False, # Disable logging stats
|
||||
)
|
||||
assert manager.prefix_cache_stats is None
|
||||
@@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
|
||||
|
||||
|
||||
def test_maybe_evict_cached_block():
|
||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
|
||||
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||
@@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
enable_kv_cache_events=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
num_tokens = block_size * blocks_to_cache
|
||||
@@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
enable_kv_cache_events=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Test with LoRA request
|
||||
@@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
use_eagle=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Request with 3 full blocks (48 tokens)
|
||||
@@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
use_eagle=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
@@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
use_eagle=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
@@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
|
||||
assert num_tokens == 0
|
||||
|
||||
|
||||
def test_different_block_size():
    """Check prefix-cache hits when KV cache groups use different block sizes.

    Group 0 is a full-attention group with ``2 * block_size`` tokens per block
    and group 1 is a sliding-window group with ``block_size`` tokens per block.
    Block hashes are computed at ``hash_block_size=block_size`` granularity.
    The test verifies the number of cached blocks reported per group and the
    resulting number of computed tokens, before and after two sliding-window
    entries are removed from the cache.
    """
    block_size = 16
    # full attention and sliding window attention layers have the same page size:
    # (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
            ),
            KVCacheGroupSpec(
                ["layer2"],
                SlidingWindowSpec(
                    block_size,
                    1,
                    1,
                    torch.float32,
                    sliding_window=2 * block_size,
                ),
            ),
        ],
    )
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # 10 hash blocks of `block_size` tokens each; hash block i holds token id i.
    common_token_ids = [i for i in range(10) for _ in range(block_size)]

    # Nothing is cached yet, so the first request must miss in both groups.
    req0 = make_request("0", common_token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks[0]
    assert not computed_blocks.blocks[1]
    assert num_computed_tokens == 0
    # Pass the token count reported by the manager rather than deriving it from
    # group 0's block count: the groups have different tokens-per-block, so a
    # "blocks * 16" conversion is only right by accident (both are 0 here).
    blocks = manager.allocate_slots(
        req0, 7 * block_size, num_computed_tokens, computed_blocks
    )
    # 7 * 16 tokens -> 4 blocks of 32 tokens (group 0), 7 blocks of 16 (group 1).
    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])

    req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 3
    assert len(computed_blocks.blocks[1]) == 6
    assert num_computed_tokens == 6 * block_size

    req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks.blocks[0]) == 3
    assert len(computed_blocks.blocks[1]) == 6
    assert num_computed_tokens == 6 * block_size

    # Evict some blocks to make the sliding window cache hit length
    # 5 * block_size. The lookup should still return 4 * block_size because the
    # full attention cache hit length must be a multiple of its own block size
    # (2 * block_size).
    manager.block_pool.cached_block_hash_to_block.pop(
        make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
    )
    manager.block_pool.cached_block_hash_to_block.pop(
        make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 2
    assert len(computed_blocks.blocks[1]) == 4
    assert num_computed_tokens == 4 * block_size
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
|
||||
Reference in New Issue
Block a user