[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
|
||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[6]]
|
||||
assert blocks.get_block_ids() == ([6], )
|
||||
|
||||
# Although we only have 6 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
|
||||
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||
assert manager.block_pool.free_block_queue.free_list_head is None
|
||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
|
||||
[9, 10, 11, 12]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
|
||||
8], [9, 10, 11, 12])
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
|
||||
[0, 10, 11]]
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
|
||||
7], [0, 10, 11])
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[13], [14], [15]]
|
||||
assert blocks.get_block_ids() == ([13], [14], [15])
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
if block != manager.block_pool.null_block:
|
||||
@@ -374,7 +374,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||
|
||||
# Check full block metadata
|
||||
@@ -400,13 +400,13 @@ def test_prefill_plp():
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@@ -444,7 +444,7 @@ def test_prefill_plp():
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||
assert block_ids != [[1, 2, 3, 4]]
|
||||
assert block_ids != ([1, 2, 3, 4], )
|
||||
|
||||
# Request #2 block hashes are valid since request #0 hashes are.
|
||||
# Check block reference counts.
|
||||
@@ -474,7 +474,7 @@ def test_decode():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
@@ -546,12 +546,12 @@ def test_evict():
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert computed_blocks.get_block_ids() == [[1, 2]]
|
||||
assert computed_blocks.get_block_ids() == ([1, 2], )
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[10]]
|
||||
assert blocks.get_block_ids() == ([10], )
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
|
||||
|
||||
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
@@ -926,7 +926,7 @@ def test_cache_key_salting():
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
|
||||
Reference in New Issue
Block a user