[v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
|
||||
hash_request_tokens,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor)
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
SlidingWindowSpec)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False):
|
||||
use_mla=False,
|
||||
sliding_window=None):
|
||||
return FullAttentionSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla)
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
def test_none_hash():
|
||||
@@ -471,6 +474,68 @@ def test_unify_kv_cache_configs():
|
||||
unify_kv_cache_configs(diff_kv_cache_config)
|
||||
|
||||
|
||||
def test_merge_kv_cache_spec():
|
||||
same_layer_specs = [
|
||||
new_kv_cache_spec(num_kv_heads=32),
|
||||
new_kv_cache_spec(num_kv_heads=32),
|
||||
]
|
||||
merged_layer_spec = same_layer_specs[0].merge(same_layer_specs)
|
||||
assert merged_layer_spec.block_size == 16
|
||||
assert merged_layer_spec.num_kv_heads == 32
|
||||
assert merged_layer_spec.head_size == 64
|
||||
assert merged_layer_spec.dtype == torch.float32
|
||||
assert merged_layer_spec.sliding_window is None
|
||||
|
||||
different_layer_specs = [
|
||||
new_kv_cache_spec(num_kv_heads=32),
|
||||
new_kv_cache_spec(num_kv_heads=16),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
different_layer_specs[0].merge(different_layer_specs)
|
||||
|
||||
full_spec = new_kv_cache_spec(num_kv_heads=32)
|
||||
different_type_layer_specs = [
|
||||
full_spec,
|
||||
SlidingWindowSpec(
|
||||
block_size=full_spec.block_size,
|
||||
num_kv_heads=full_spec.num_kv_heads,
|
||||
head_size=full_spec.head_size,
|
||||
dtype=full_spec.dtype,
|
||||
use_mla=full_spec.use_mla,
|
||||
sliding_window=1,
|
||||
),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
different_type_layer_specs[0].merge(different_type_layer_specs)
|
||||
with pytest.raises(AssertionError):
|
||||
different_type_layer_specs[1].merge(different_type_layer_specs)
|
||||
|
||||
different_sliding_window_layer_specs = [
|
||||
new_kv_cache_spec(num_kv_heads=32),
|
||||
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
|
||||
new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
different_sliding_window_layer_specs[0].merge(
|
||||
different_sliding_window_layer_specs)
|
||||
|
||||
same_sliding_window_layer_specs = [
|
||||
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
|
||||
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
|
||||
]
|
||||
merged_layer_spec = same_sliding_window_layer_specs[0].merge(
|
||||
same_sliding_window_layer_specs)
|
||||
assert merged_layer_spec.sliding_window == 1
|
||||
|
||||
same_sliding_window_layer_spec_with_none = [
|
||||
new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
|
||||
new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
|
||||
]
|
||||
merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge(
|
||||
same_sliding_window_layer_spec_with_none)
|
||||
assert merged_layer_spec.sliding_window == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "max_model_len", "want_estimated_max_len"), [
|
||||
("Qwen/Qwen1.5-7B", 16385, 16384),
|
||||
|
||||
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [5]
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
for block in computed_blocks.blocks:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
|
||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [6]
|
||||
assert blocks.get_block_ids() == [[6]]
|
||||
|
||||
# Although we only have 6 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
|
||||
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||
assert manager.block_pool.free_block_queue.free_list_head is None
|
||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||
@@ -208,7 +208,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks]
|
||||
|
||||
# Check full block metadata
|
||||
@@ -233,13 +233,13 @@ def test_prefill_plp():
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [5]
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
for block in computed_blocks.blocks:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@@ -277,11 +277,11 @@ def test_prefill_plp():
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
|
||||
assert block_ids != [1, 2, 3, 4]
|
||||
assert block_ids != [[1, 2, 3, 4]]
|
||||
|
||||
# Request #2 block hashes are valid since request #0 hashes are.
|
||||
# Check block reference counts.
|
||||
for block_id in block_ids:
|
||||
for block_id in block_ids[0]:
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
|
||||
manager.free(req2)
|
||||
@@ -307,7 +307,7 @@ def test_decode():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
@@ -379,12 +379,12 @@ def test_evict():
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert computed_blocks.get_block_ids() == [1, 2]
|
||||
assert computed_blocks.get_block_ids() == [[1, 2]]
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [10]
|
||||
assert blocks.get_block_ids() == [[10]]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
|
||||
|
||||
@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
@@ -686,7 +686,7 @@ def test_cache_key_salting():
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [5]
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
|
||||
Reference in New Issue
Block a user