[Hybrid Allocator] Support KV cache groups with different block_size (#29143)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
|
||||
)
|
||||
|
||||
# Test case 1: Requires additional lookahead tokens
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
||||
kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||
)
|
||||
blocks = kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens=3,
|
||||
@@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
|
||||
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
|
||||
|
||||
# Test case 2: With precomputed blocks
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
||||
kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||
)
|
||||
# required_blocks = ceil((3 + 2) /4) = 2
|
||||
blocks = kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
@@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
|
||||
|
||||
# Test case 3: With precomputed blocks
|
||||
# required_blocks = ceil((3 + 4) / 4) = 2
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
||||
kv_cache_manager = KVCacheManager(
|
||||
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||
)
|
||||
blocks = kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens=3,
|
||||
@@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
|
||||
),
|
||||
],
|
||||
)
|
||||
# different hidden size
|
||||
|
||||
# different hidden size but same type, use UniformTypeKVCacheSpecs
|
||||
kv_cache_specs_hybrid = {
|
||||
"layer_1": new_kv_cache_spec(head_size=128),
|
||||
"layer_2": new_kv_cache_spec(head_size=64),
|
||||
@@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
|
||||
],
|
||||
)
|
||||
|
||||
# Different hidden size and different type, align by different block size
|
||||
kv_cache_specs_hybrid = {
|
||||
"layer_1": new_kv_cache_spec(head_size=64),
|
||||
"layer_2": new_sliding_window_spec(head_size=32),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
|
||||
)[0]
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
|
||||
KVCacheGroupSpec(
|
||||
["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# different hidden size that cannot be aligned by using different block size
|
||||
kv_cache_specs_hybrid = {
|
||||
"layer_1": new_kv_cache_spec(head_size=64),
|
||||
"layer_2": new_sliding_window_spec(head_size=96),
|
||||
}
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
|
||||
)[0]
|
||||
|
||||
# Test num_gpu_blocks_override
|
||||
vllm_config.cache_config.num_gpu_blocks_override = 16
|
||||
kv_cache_config_override_blocks = get_kv_cache_configs(
|
||||
|
||||
@@ -134,6 +134,7 @@ def test_prefill(hash_fn):
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
@@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
|
||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
hash_fn = sha256
|
||||
@@ -416,6 +418,7 @@ def test_prefill_plp():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# the default hash function is sha256
|
||||
hash_fn = sha256
|
||||
@@ -523,6 +526,7 @@ def test_decode():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
@@ -585,6 +589,7 @@ def test_evict():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
last_token_id = 5 * 16 + 7
|
||||
@@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
|
||||
make_kv_cache_config(16, 2),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Allocate 1 block and cache it.
|
||||
@@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
|
||||
make_kv_cache_config(block_size, 3),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Allocate a block and cache it.
|
||||
@@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
|
||||
make_kv_cache_config(block_size, 5),
|
||||
max_model_len=8192,
|
||||
enable_caching=False,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
req1 = make_request(
|
||||
@@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=5,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# Req:
|
||||
# Block 0: [0, 1, 2, 3]
|
||||
@@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
|
||||
This tests that blocks are cached correctly for different kv cache groups.
|
||||
"""
|
||||
block_size = 4
|
||||
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
|
||||
# Req:
|
||||
# Block 0/4: [0, 1, 2, 3]
|
||||
@@ -921,6 +932,7 @@ def test_mm_prefix_caching():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
|
||||
@@ -1020,6 +1032,7 @@ def test_cache_key_salting():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||
@@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# Complete 3 blocks (48 tokens)
|
||||
# | Common-0 | Common-1 | Common-2 | ... |
|
||||
@@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
log_stats=False, # Disable logging stats
|
||||
)
|
||||
assert manager.prefix_cache_stats is None
|
||||
@@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
|
||||
|
||||
|
||||
def test_maybe_evict_cached_block():
|
||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
|
||||
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||
@@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
enable_kv_cache_events=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
num_tokens = block_size * blocks_to_cache
|
||||
@@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
enable_kv_cache_events=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Test with LoRA request
|
||||
@@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
use_eagle=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# Request with 3 full blocks (48 tokens)
|
||||
@@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
use_eagle=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
@@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
use_eagle=True,
|
||||
hash_block_size=block_size,
|
||||
)
|
||||
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
@@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
|
||||
assert num_tokens == 0
|
||||
|
||||
|
||||
def test_different_block_size():
    """Prefix-cache lookup across two KV cache groups with different block sizes.

    Group 0 is a full-attention spec with 32-token blocks and group 1 is a
    sliding-window spec with 16-token blocks; hashing uses 16-token blocks.
    The test allocates one request, then checks how many blocks per group and
    how many tokens are reported as cache hits for shorter follow-up prompts,
    before and after two cached group-1 entries are removed.
    """
    block_size = 16

    # Both layers end up with the same page size:
    # 32 tokens/block of float16 vs. 16 tokens/block of float32.
    full_attn_spec = FullAttentionSpec(block_size * 2, 1, 1, torch.float16)
    sliding_spec = SlidingWindowSpec(
        block_size,
        1,
        1,
        torch.float32,
        sliding_window=2 * block_size,
    )
    groups = [
        KVCacheGroupSpec(["layer1"], full_attn_spec),
        KVCacheGroupSpec(["layer2"], sliding_spec),
    ]
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # 160 token ids: each value 0..9 repeated block_size (16) times.
    common_token_ids = [i for i in range(10) for _ in range(block_size)]

    # First request: nothing is cached yet, so both groups report no hit.
    req0 = make_request("0", common_token_ids, block_size, sha256)
    hit_blocks, hit_tokens = manager.get_computed_blocks(req0)
    assert not hit_blocks.blocks[0]
    assert not hit_blocks.blocks[1]
    assert hit_tokens == 0

    # Schedule 7 * 16 new tokens; the computed-token count is derived from the
    # (empty) group-0 hit list, so it is 0 here.
    num_new_tokens = 7 * block_size
    num_new_computed_tokens = len(hit_blocks.blocks[0]) * 16
    blocks = manager.allocate_slots(
        req0, num_new_tokens, num_new_computed_tokens, hit_blocks
    )
    # Expected: 4 block ids for group 0 and 7 block ids for group 1.
    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])

    # A prompt of 7 blocks + 1 token is expected to hit 6 * 16 tokens.
    req1_prompt = common_token_ids[: 7 * block_size + 1]
    req1 = make_request("1", req1_prompt, block_size, sha256)
    hit_blocks, hit_tokens = manager.get_computed_blocks(req1)
    assert len(hit_blocks.blocks[0]) == 3
    assert len(hit_blocks.blocks[1]) == 6
    assert hit_tokens == 6 * 16

    # A prompt of 6 blocks + 1 token is expected to hit the same 6 * 16 tokens.
    req2_prompt = common_token_ids[: 6 * block_size + 1]
    req2 = make_request("2", req2_prompt, block_size, sha256)
    hit_blocks, hit_tokens = manager.get_computed_blocks(req2)
    assert len(hit_blocks.blocks[0]) == 3
    assert len(hit_blocks.blocks[1]) == 6
    assert hit_tokens == 6 * 16

    # Evict some blocks to make sliding window cache hit length 5*16.
    # The lookup should still return 4 * 16, because the full attention cache
    # hit length must be a multiple of 32.
    cached_map = manager.block_pool.cached_block_hash_to_block
    for hash_index, block_id in ((6, 11), (5, 10)):
        cached_map.pop(
            make_block_hash_with_group_id(req1.block_hashes[hash_index], 1),
            block_id,
        )
    hit_blocks, hit_tokens = manager.get_computed_blocks(req1)
    assert len(hit_blocks.blocks[0]) == 2
    assert len(hit_blocks.blocks[1]) == 4
    assert hit_tokens == 4 * 16
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
|
||||
@@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
attention_chunk_size=4,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
manager = get_chunked_local_attention_manager(
|
||||
chunked_local_attention_spec, block_pool
|
||||
)
|
||||
@@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
block_pool=block_pool,
|
||||
kv_cache_spec=chunked_local_attention_spec,
|
||||
use_eagle=False,
|
||||
alignment_tokens=block_size,
|
||||
)[0]
|
||||
assert len(computed_blocks) == expect_length
|
||||
|
||||
@@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
|
||||
sliding_window=4,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||
|
||||
def run_one_case(block_is_cached, expect_length):
|
||||
@@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
|
||||
block_pool=block_pool,
|
||||
kv_cache_spec=sliding_window_spec,
|
||||
use_eagle=False,
|
||||
alignment_tokens=block_size,
|
||||
)[0]
|
||||
assert len(computed_blocks) == expect_length
|
||||
|
||||
@@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
|
||||
attention_chunk_size=4,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
|
||||
|
||||
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
||||
|
||||
@@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
|
||||
sliding_window=4,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
|
||||
|
||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||
|
||||
@@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
|
||||
sliding_window=4, # Placeholder value, not related to test result
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
||||
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
||||
@@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
||||
attention_chunk_size=4, # Placeholder value, not related to test result
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
||||
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
||||
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
||||
|
||||
Reference in New Issue
Block a user