[Hybrid Allocator] Support KV cache groups with different block_size (#29143)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Yifan Qiao
2025-11-25 07:30:57 -08:00
committed by GitHub
parent e502098643
commit 48ddb02b79
11 changed files with 472 additions and 87 deletions

View File

@@ -134,6 +134,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
hash_fn = sha256
@@ -416,6 +418,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# the default hash function is sha256
hash_fn = sha256
@@ -523,6 +526,7 @@ def test_decode():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@@ -585,6 +589,7 @@ def test_evict():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
last_token_id = 5 * 16 + 7
@@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Allocate 1 block and cache it.
@@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Allocate a block and cache it.
@@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
hash_block_size=block_size,
)
req1 = make_request(
@@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
hash_block_size=block_size,
)
# Req:
# Block 0: [0, 1, 2, 3]
@@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size = 4
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
block_pool = BlockPool(
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
)
# Req:
# Block 0/4: [0, 1, 2, 3]
@@ -921,6 +932,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -1020,6 +1032,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# 3 complete blocks and an incomplete block with 11 tokens.
@@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
full_block_token_ids = [i for i in range(3) for _ in range(16)]
@@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
log_stats=False, # Disable logging stats
)
assert manager.prefix_cache_stats is None
@@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
hash_block_size=block_size,
)
num_tokens = block_size * blocks_to_cache
@@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
hash_block_size=block_size,
)
# Test with LoRA request
@@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# Request with 3 full blocks (48 tokens)
@@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
@@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
@@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0
def test_different_block_size():
    """Exercise prefix caching across two KV cache groups with different
    block sizes but equal page sizes.

    Group 0 (full attention) uses blocks of 32 tokens (float16);
    group 1 (sliding window) uses blocks of 16 tokens (float32), so one
    full-attention block covers two sliding-window blocks. The test checks
    that cache-hit lengths reported by ``get_computed_blocks`` are clamped
    to a multiple of the *largest* group's block size.
    """
    block_size = 16
    # full attention and sliding window attention layers have the same page size:
    # (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                # 32-token blocks for the full-attention group.
                FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
            ),
            KVCacheGroupSpec(
                ["layer2"],
                SlidingWindowSpec(
                    block_size,
                    1,
                    1,
                    torch.float32,
                    sliding_window=2 * block_size,
                ),
            ),
        ],
    )
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        # Hashing is done at the smaller (16-token) granularity.
        hash_block_size=block_size,
    )
    # 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
    common_token_ids = [i for i in range(10) for _ in range(block_size)]
    req0 = make_request("0", common_token_ids, block_size, sha256)
    # Cold cache: no computed blocks in either group.
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks.blocks[0]
    assert not computed_blocks.blocks[1]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
    )
    # 7 * 16 = 112 tokens -> 4 blocks of 32 tokens (group 0) and
    # 7 blocks of 16 tokens (group 1); block ids are allocated sequentially.
    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])

    # req1 shares 7 full 16-token blocks (+1 token) with req0, but the hit is
    # reported as 6 * 16 tokens: the last token of a request is never counted
    # as cached, and the hit must align to the 32-token group's blocks.
    req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 3
    assert len(computed_blocks.blocks[1]) == 6
    assert num_computed_tokens == 6 * 16

    # req2 shares only 6 full blocks (+1 token); same clamped result.
    req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks.blocks[0]) == 3
    assert len(computed_blocks.blocks[1]) == 6
    assert num_computed_tokens == 6 * 16

    # Evict some blocks to make sliding window cache hit length 5*16
    # But should return 4 * 16 because full attention cache hit length must be
    # a multiple of 32
    # NOTE(review): the second argument to pop() here (11 / 10) looks like the
    # block id being removed from the per-hash map, not a dict default —
    # confirm against BlockHashToBlockMap.pop's signature.
    manager.block_pool.cached_block_hash_to_block.pop(
        make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
    )
    manager.block_pool.cached_block_hash_to_block.pop(
        make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
    )
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks[0]) == 2
    assert len(computed_blocks.blocks[1]) == 4
    assert num_computed_tokens == 4 * 16
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")