|
|
|
|
@@ -79,10 +79,10 @@ def test_prefill(hash_algo):
|
|
|
|
|
req0 = make_request("0", all_token_ids)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
|
|
|
|
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
|
|
|
|
|
|
|
|
|
# Check full block metadata
|
|
|
|
|
parent_block_hash = None
|
|
|
|
|
@@ -105,12 +105,12 @@ def test_prefill(hash_algo):
|
|
|
|
|
req1 = make_request("1", common_token_ids + unique_token_ids)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
|
|
|
|
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
|
|
|
|
|
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
|
|
|
|
assert num_computed_tokens == 3 * 16
|
|
|
|
|
num_new_tokens = 53 - 3 * 16
|
|
|
|
|
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [5]
|
|
|
|
|
for block in computed_blocks:
|
|
|
|
|
assert blocks.get_block_ids() == [5]
|
|
|
|
|
for block in computed_blocks.blocks:
|
|
|
|
|
assert block.ref_cnt == 2
|
|
|
|
|
|
|
|
|
|
# At this point, we should have 5 free blocks left.
|
|
|
|
|
@@ -137,11 +137,11 @@ def test_prefill(hash_algo):
|
|
|
|
|
req2 = make_request("2", common_token_ids + unique_token_ids)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
|
|
|
|
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
|
|
|
|
|
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
|
|
|
|
assert num_computed_tokens == 3 * 16
|
|
|
|
|
num_new_tokens = 53 - 3 * 16
|
|
|
|
|
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [6]
|
|
|
|
|
assert blocks.get_block_ids() == [6]
|
|
|
|
|
|
|
|
|
|
# Although we only have 6 free blocks, we have 8 blocks in
|
|
|
|
|
# the free block queue due to lazy removal.
|
|
|
|
|
@@ -159,11 +159,11 @@ def test_prefill(hash_algo):
|
|
|
|
|
# Cache miss and eviction.
|
|
|
|
|
req3 = make_request("3", [99] * (16 * 10))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
|
|
|
|
|
# This block ID order also checks the eviction order.
|
|
|
|
|
assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
|
|
|
|
|
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
|
|
|
|
|
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
|
|
|
|
assert manager.block_pool.free_block_queue.free_list_head is None
|
|
|
|
|
assert manager.block_pool.free_block_queue.free_list_tail is None
|
|
|
|
|
@@ -195,11 +195,11 @@ def test_prefill_plp():
|
|
|
|
|
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
|
|
|
|
req0_block_hashes = [b.block_hash for b in blocks]
|
|
|
|
|
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
|
|
|
|
req0_block_hashes = [b.block_hash for b in blocks.blocks]
|
|
|
|
|
|
|
|
|
|
# Check full block metadata
|
|
|
|
|
parent_block_hash = None
|
|
|
|
|
@@ -223,12 +223,12 @@ def test_prefill_plp():
|
|
|
|
|
req1 = make_request("1", common_token_ids + unique_token_ids)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
|
|
|
|
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
|
|
|
|
|
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
|
|
|
|
assert num_computed_tokens == 3 * 16
|
|
|
|
|
num_new_tokens = 53 - 3 * 16
|
|
|
|
|
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [5]
|
|
|
|
|
for block in computed_blocks:
|
|
|
|
|
assert blocks.get_block_ids() == [5]
|
|
|
|
|
for block in computed_blocks.blocks:
|
|
|
|
|
assert block.ref_cnt == 2
|
|
|
|
|
|
|
|
|
|
# At this point, we should have 5 free blocks left.
|
|
|
|
|
@@ -257,12 +257,12 @@ def test_prefill_plp():
|
|
|
|
|
prompt_logprobs=5)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req2, 55, computed_blocks)
|
|
|
|
|
block_ids = [b.block_id for b in blocks]
|
|
|
|
|
block_ids = blocks.get_block_ids()
|
|
|
|
|
# Duplicate cached blocks have different ids but same hashes vs request #0
|
|
|
|
|
assert [b.block_hash for b in blocks] == req0_block_hashes
|
|
|
|
|
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
|
|
|
|
|
assert block_ids != [1, 2, 3, 4]
|
|
|
|
|
|
|
|
|
|
# Request #2 block hashes are valid since request #0 hashes are.
|
|
|
|
|
@@ -288,17 +288,17 @@ def test_decode():
|
|
|
|
|
unique_token_ids = [3] * 7
|
|
|
|
|
req0 = make_request("0", common_token_ids + unique_token_ids)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
|
|
|
|
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
|
|
|
|
|
|
|
|
|
# Append slots without allocating a new block.
|
|
|
|
|
req0.num_computed_tokens = 55
|
|
|
|
|
for _ in range(4):
|
|
|
|
|
req0.append_output_token_ids(8)
|
|
|
|
|
new_blocks = manager.allocate_slots(req0, 4)
|
|
|
|
|
assert new_blocks is not None and len(new_blocks) == 0
|
|
|
|
|
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
|
|
|
|
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
|
|
|
|
|
|
|
|
|
|
# Append slots with allocating a new block.
|
|
|
|
|
@@ -308,7 +308,7 @@ def test_decode():
|
|
|
|
|
for _ in range(9 + 10):
|
|
|
|
|
req0.append_output_token_ids(7)
|
|
|
|
|
new_blocks = manager.allocate_slots(req0, 19)
|
|
|
|
|
assert new_blocks is not None and len(new_blocks) == 1
|
|
|
|
|
assert new_blocks is not None and len(new_blocks.blocks) == 1
|
|
|
|
|
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
|
|
|
|
|
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
|
|
|
|
|
|
|
|
|
|
@@ -323,19 +323,19 @@ def test_evict():
|
|
|
|
|
last_token_id = 5 * 16 + 7
|
|
|
|
|
req0 = make_request("0", list(range(last_token_id)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
|
|
|
|
|
assert len(blocks) == 6 # 5 full + 1 partial
|
|
|
|
|
assert len(blocks.blocks) == 6 # 5 full + 1 partial
|
|
|
|
|
|
|
|
|
|
# 3 blocks.
|
|
|
|
|
req1 = make_request("1", list(range(last_token_id,
|
|
|
|
|
last_token_id + 3 * 16)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
|
|
|
|
|
assert len(blocks) == 3 # 3 full blocks
|
|
|
|
|
assert len(blocks.blocks) == 3 # 3 full blocks
|
|
|
|
|
last_token_id += 3 * 16
|
|
|
|
|
|
|
|
|
|
# 10 - (6 + 3) == 1
|
|
|
|
|
@@ -352,10 +352,10 @@ def test_evict():
|
|
|
|
|
# Touch the first 2 blocks.
|
|
|
|
|
req2 = make_request("2", list(range(2 * 16 + 3)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert [b.block_id for b in computed_blocks] == [1, 2]
|
|
|
|
|
assert computed_blocks.get_block_ids() == [1, 2]
|
|
|
|
|
assert num_computed_tokens == 2 * 16
|
|
|
|
|
blocks = manager.allocate_slots(req2, 3, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [10]
|
|
|
|
|
assert blocks.get_block_ids() == [10]
|
|
|
|
|
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -375,10 +375,10 @@ def test_hash_block_correct_reuse():
|
|
|
|
|
num_tokens = block_size * 1
|
|
|
|
|
req = make_request("0", list(range(num_tokens)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
|
|
|
|
|
assert len(blocks) == 1
|
|
|
|
|
assert len(blocks.blocks) == 1
|
|
|
|
|
|
|
|
|
|
# Deallocate the block.
|
|
|
|
|
manager.free(req)
|
|
|
|
|
@@ -387,12 +387,13 @@ def test_hash_block_correct_reuse():
|
|
|
|
|
# block is cleared.
|
|
|
|
|
req = make_request("1", list(range(num_tokens - 1)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
|
|
|
|
|
assert len(blocks) == 1
|
|
|
|
|
assert len(blocks.blocks) == 1
|
|
|
|
|
|
|
|
|
|
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
|
|
|
|
|
assert manager.block_pool.blocks[
|
|
|
|
|
blocks.blocks[0].block_id].block_hash is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_computed_blocks_not_evicted():
|
|
|
|
|
@@ -411,20 +412,20 @@ def test_computed_blocks_not_evicted():
|
|
|
|
|
num_tokens = block_size * 1
|
|
|
|
|
req0 = make_request("0", list(range(num_tokens)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
|
|
|
|
|
assert len(blocks) == 1
|
|
|
|
|
assert blocks[0].block_id == 1
|
|
|
|
|
assert len(blocks.blocks) == 1
|
|
|
|
|
assert blocks.blocks[0].block_id == 1
|
|
|
|
|
|
|
|
|
|
# Allocate another block.
|
|
|
|
|
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
|
|
|
|
|
assert len(blocks) == 1
|
|
|
|
|
assert blocks[0].block_id == 2
|
|
|
|
|
assert len(blocks.blocks) == 1
|
|
|
|
|
assert blocks.blocks[0].block_id == 2
|
|
|
|
|
|
|
|
|
|
# Free the blocks.
|
|
|
|
|
manager.free(req0)
|
|
|
|
|
@@ -434,14 +435,14 @@ def test_computed_blocks_not_evicted():
|
|
|
|
|
# cached block rather than the first one.
|
|
|
|
|
req2 = make_request("2", list(range(num_tokens * 2)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert len(computed_blocks) == 1
|
|
|
|
|
assert computed_blocks[0].block_id == 1
|
|
|
|
|
assert len(computed_blocks.blocks) == 1
|
|
|
|
|
assert computed_blocks.blocks[0].block_id == 1
|
|
|
|
|
assert num_computed_tokens == block_size
|
|
|
|
|
|
|
|
|
|
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
|
|
|
|
computed_blocks)
|
|
|
|
|
assert len(blocks) == 1
|
|
|
|
|
assert blocks[0].block_id == 2
|
|
|
|
|
assert len(blocks.blocks) == 1
|
|
|
|
|
assert blocks.blocks[0].block_id == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_basic_prefix_caching_disabled():
|
|
|
|
|
@@ -458,10 +459,10 @@ def test_basic_prefix_caching_disabled():
|
|
|
|
|
req1 = make_request("1", list(range(10))) # 2 blocks and some more
|
|
|
|
|
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req1, 10, computed_blocks)
|
|
|
|
|
assert len(blocks) == 3
|
|
|
|
|
assert len(blocks.blocks) == 3
|
|
|
|
|
|
|
|
|
|
# Free the blocks.
|
|
|
|
|
manager.free(req1)
|
|
|
|
|
@@ -469,15 +470,15 @@ def test_basic_prefix_caching_disabled():
|
|
|
|
|
# No caching.
|
|
|
|
|
req2 = make_request("2", list(range(16))) # shared prefix
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req2, 16, computed_blocks)
|
|
|
|
|
assert len(blocks) == 4
|
|
|
|
|
assert len(blocks.blocks) == 4
|
|
|
|
|
|
|
|
|
|
# New requests should not have any blocks.
|
|
|
|
|
req3 = make_request("3", list(range(4)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
blocks = manager.allocate_slots(req3, 4, computed_blocks)
|
|
|
|
|
assert not blocks
|
|
|
|
|
@@ -569,7 +570,7 @@ def test_mm_prefix_caching():
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
|
|
|
|
|
# Completed block should have hashes with extra keys.
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
|
|
|
|
assert len(block_hashes) == 3
|
|
|
|
|
@@ -578,14 +579,14 @@ def test_mm_prefix_caching():
|
|
|
|
|
assert block_hashes[2].extra_keys == ("bbb", )
|
|
|
|
|
|
|
|
|
|
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
|
|
|
|
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
|
|
|
|
req0.num_computed_tokens = 59
|
|
|
|
|
|
|
|
|
|
# Append slots without allocating a new block.
|
|
|
|
|
for _ in range(5):
|
|
|
|
|
req0.append_output_token_ids(8)
|
|
|
|
|
new_blocks = manager.allocate_slots(req0, 5)
|
|
|
|
|
assert new_blocks is not None and len(new_blocks) == 0
|
|
|
|
|
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
|
|
|
|
|
|
|
|
|
# The just completed block should have hashes with extra keys.
|
|
|
|
|
assert len(block_hashes) == 4
|
|
|
|
|
@@ -603,7 +604,7 @@ def test_mm_prefix_caching():
|
|
|
|
|
mm_positions=mm_positions,
|
|
|
|
|
mm_hashes=mm_hashes)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert len(computed_blocks) == 3
|
|
|
|
|
assert len(computed_blocks.blocks) == 3
|
|
|
|
|
assert num_computed_tokens == 3 * 16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -626,7 +627,7 @@ def test_cache_key_salting():
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
|
|
|
|
|
# Completed block should have hashes with extra keys.
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
|
|
|
|
assert len(block_hashes) == 3
|
|
|
|
|
@@ -635,14 +636,14 @@ def test_cache_key_salting():
|
|
|
|
|
assert block_hashes[2].extra_keys is None
|
|
|
|
|
|
|
|
|
|
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
|
|
|
|
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
|
|
|
|
req0.num_computed_tokens = 59
|
|
|
|
|
|
|
|
|
|
# Append slots without allocating a new block.
|
|
|
|
|
for _ in range(5):
|
|
|
|
|
req0.append_output_token_ids(8)
|
|
|
|
|
new_blocks = manager.allocate_slots(req0, 5)
|
|
|
|
|
assert new_blocks is not None and len(new_blocks) == 0
|
|
|
|
|
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
|
|
|
|
|
|
|
|
|
# Now one more block that should not have extra keys.
|
|
|
|
|
assert len(block_hashes) == 4
|
|
|
|
|
@@ -653,14 +654,14 @@ def test_cache_key_salting():
|
|
|
|
|
req1 = make_request("1", token_ids, cache_salt="salt1")
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
# Should match only a prefix of 3 blocks.
|
|
|
|
|
assert len(computed_blocks) == 3
|
|
|
|
|
assert len(computed_blocks.blocks) == 3
|
|
|
|
|
assert num_computed_tokens == 3 * block_size
|
|
|
|
|
|
|
|
|
|
# Test cache miss with same content but different salt.
|
|
|
|
|
token_ids = common_token_ids + [4] * 11
|
|
|
|
|
req2 = make_request("2", token_ids, cache_salt="salt2")
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert len(computed_blocks) == 0
|
|
|
|
|
assert len(computed_blocks.blocks) == 0
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
block_hashes = manager.req_to_block_hashes[req2.request_id]
|
|
|
|
|
assert len(block_hashes) == 3
|
|
|
|
|
@@ -685,7 +686,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|
|
|
|
common_token_ids = [i for i in range(3) for _ in range(16)]
|
|
|
|
|
req0 = make_request("0", common_token_ids)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
manager.allocate_slots(req0, 48, computed_blocks)
|
|
|
|
|
block_part0 = manager.req_to_blocks[req0.request_id]
|
|
|
|
|
@@ -693,7 +694,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|
|
|
|
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
|
|
|
|
req1 = make_request("1", common_token_ids * 2)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
|
|
|
|
assert computed_blocks == block_part0
|
|
|
|
|
assert computed_blocks.blocks == block_part0
|
|
|
|
|
assert num_computed_tokens == 3 * 16
|
|
|
|
|
manager.allocate_slots(req1, 48, computed_blocks)
|
|
|
|
|
block_part1 = manager.req_to_blocks[req1.request_id]
|
|
|
|
|
@@ -707,7 +708,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|
|
|
|
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
|
|
|
|
req2 = make_request("2", [7] * block_size * 2)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
manager.allocate_slots(req2, block_size * 2, computed_blocks)
|
|
|
|
|
|
|
|
|
|
@@ -717,7 +718,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|
|
|
|
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
|
|
|
|
req3 = make_request("3", common_token_ids * 3)
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
|
|
|
|
assert computed_blocks == block_part1
|
|
|
|
|
assert computed_blocks.blocks == block_part1
|
|
|
|
|
assert num_computed_tokens == 6 * 16
|
|
|
|
|
# Req3 cannot be allocated.
|
|
|
|
|
assert manager.allocate_slots(req3, 48, computed_blocks) is None
|
|
|
|
|
@@ -739,16 +740,16 @@ def test_reset_prefix_cache():
|
|
|
|
|
all_token_ids = full_block_token_ids + unique_token_ids
|
|
|
|
|
req0 = make_request("0", all_token_ids)
|
|
|
|
|
blocks = manager.allocate_slots(req0, 55)
|
|
|
|
|
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
|
|
|
|
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
|
|
|
|
|
|
|
|
|
unique_token_ids = [4] * 7
|
|
|
|
|
all_token_ids = full_block_token_ids + unique_token_ids
|
|
|
|
|
req1 = make_request("1", all_token_ids)
|
|
|
|
|
computed_blocks, _ = manager.get_computed_blocks(req1)
|
|
|
|
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
|
|
|
|
assert len(computed_blocks) == 3
|
|
|
|
|
assert len(computed_blocks.blocks) == 3
|
|
|
|
|
blocks = manager.allocate_slots(req1, 7, computed_blocks)
|
|
|
|
|
assert [b.block_id for b in blocks] == [5]
|
|
|
|
|
assert blocks.get_block_ids() == [5]
|
|
|
|
|
|
|
|
|
|
# Failed to reset prefix cache because some blocks are not freed yet.
|
|
|
|
|
assert not manager.reset_prefix_cache()
|
|
|
|
|
@@ -776,7 +777,7 @@ def test_prefix_cache_stats_disabled():
|
|
|
|
|
# Call all functions that check whether log_stats is disabled.
|
|
|
|
|
req = make_request("0", list(range(16)))
|
|
|
|
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
|
|
|
|
assert not computed_blocks
|
|
|
|
|
assert not computed_blocks.blocks
|
|
|
|
|
assert num_computed_tokens == 0
|
|
|
|
|
manager.allocate_slots(req, 16, computed_blocks)
|
|
|
|
|
manager.reset_prefix_cache()
|
|
|
|
|
@@ -866,7 +867,7 @@ def test_eagle_enabled_removes_last_block():
|
|
|
|
|
# Should retain 1 block:
|
|
|
|
|
# 1. Original 3 blocks → pop last hash → 2 matched blocks
|
|
|
|
|
# 2. drop last matched block → 1 remaining block
|
|
|
|
|
assert len(computed_blocks) == 1
|
|
|
|
|
assert len(computed_blocks.blocks) == 1
|
|
|
|
|
assert num_tokens == 1 * block_size # 16 tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -892,7 +893,7 @@ def test_eagle_with_partial_blocks():
|
|
|
|
|
req_eagle = make_request("partial_eagle", token_ids)
|
|
|
|
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
|
|
|
|
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
|
|
|
|
assert len(computed_blocks) == 1
|
|
|
|
|
assert len(computed_blocks.blocks) == 1
|
|
|
|
|
assert num_tokens == 1 * block_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -934,7 +935,7 @@ def test_eagle_with_sliding_window():
|
|
|
|
|
req_eagle = make_request("partial_eagle", token_ids)
|
|
|
|
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
|
|
|
|
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
|
|
|
|
assert len(computed_blocks) == 1
|
|
|
|
|
assert len(computed_blocks.blocks) == 1
|
|
|
|
|
assert num_tokens == 1 * block_size
|
|
|
|
|
|
|
|
|
|
# Evict the first block in the request
|
|
|
|
|
@@ -948,5 +949,5 @@ def test_eagle_with_sliding_window():
|
|
|
|
|
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
|
|
|
|
# not considered. But after dropping the last matched block due to eagle,
|
|
|
|
|
# there will be no matched prefix.
|
|
|
|
|
assert len(computed_blocks) == 0
|
|
|
|
|
assert len(computed_blocks.blocks) == 0
|
|
|
|
|
assert num_tokens == 0
|
|
|
|
|
|