[v1] Move block pool operations to a separate class (#13973)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Compare the behavior with and without prefix caching."""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
hash_block_tokens)
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
@@ -62,14 +66,14 @@ def test_prefill():
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
assert manager.block_pool[block_id].block_hash == block_hash
|
||||
assert manager.block_pool[block_id].ref_cnt == 1
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
# Check partial/preallocated block metadata
|
||||
for block_id in (3, 4):
|
||||
assert manager.block_pool[block_id].block_hash is None
|
||||
assert manager.block_pool[block_id].ref_cnt == 1
|
||||
assert manager.block_pool.blocks[block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
|
||||
# Cache hit in the common prefix when the original block is still in use.
|
||||
# 1 incomplete block (5 tokens)
|
||||
@@ -86,20 +90,21 @@ def test_prefill():
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
# At this point, we should have 3 free blocks left.
|
||||
assert manager.free_block_queue.num_free_blocks == 3
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 3
|
||||
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
# All blocks should be available.
|
||||
assert manager.free_block_queue.num_free_blocks == 10
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 10
|
||||
# The order should be
|
||||
# [unallocated (7, 8, 9)]
|
||||
# [unique_req0 (4, 3)]
|
||||
# [unique_req1 (6, 5)]
|
||||
# [common (2, 1, 0)]
|
||||
assert [
|
||||
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
|
||||
b.block_id
|
||||
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
|
||||
|
||||
# Cache hit in the common prefix when the original block is already free.
|
||||
@@ -116,12 +121,14 @@ def test_prefill():
|
||||
|
||||
# Although we only have 5 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
assert manager.free_block_queue.num_free_blocks == 5
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||
assert all([
|
||||
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
|
||||
b.ref_cnt == 0
|
||||
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
])
|
||||
assert len([b
|
||||
for b in manager.free_block_queue.get_all_free_blocks()]) == 5
|
||||
assert len([
|
||||
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
]) == 5
|
||||
|
||||
manager.free(req2)
|
||||
|
||||
@@ -133,9 +140,9 @@ def test_prefill():
|
||||
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
|
||||
assert manager.free_block_queue.num_free_blocks == 0
|
||||
assert manager.free_block_queue.free_list_head is None
|
||||
assert manager.free_block_queue.free_list_tail is None
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||
assert manager.block_pool.free_block_queue.free_list_head is None
|
||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||
|
||||
|
||||
def test_decode():
|
||||
@@ -219,13 +226,14 @@ def test_evict():
|
||||
assert len(blocks) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
|
||||
assert manager.free_block_queue.num_free_blocks == 0
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
assert manager.free_block_queue.num_free_blocks == 10
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 10
|
||||
assert [
|
||||
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
|
||||
b.block_id
|
||||
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
|
||||
|
||||
# Touch the first 2 blocks.
|
||||
@@ -235,7 +243,7 @@ def test_evict():
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [6, 5]
|
||||
assert manager.free_block_queue.num_free_blocks == 6
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 6
|
||||
|
||||
|
||||
def test_hash_block_correct_reuse():
|
||||
@@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
|
||||
assert manager.block_pool[blocks[0].block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
|
||||
|
||||
|
||||
def test_computed_blocks_not_evicted():
|
||||
@@ -413,13 +421,9 @@ def test_cache_blocks():
|
||||
function of KVCacheManager.
|
||||
"""
|
||||
block_size = 4
|
||||
manager = KVCacheManager(
|
||||
block_size=block_size,
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=5,
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
# Req:
|
||||
# Block 0: [0, 1, 2, 3]
|
||||
@@ -430,26 +434,31 @@ def test_cache_blocks():
|
||||
|
||||
# Test that blocks are cached correctly for 2 full blocks from the start.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||
block_hashes: List[BlockHashType] = []
|
||||
|
||||
manager._cache_full_blocks(
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blk_start_idx=0,
|
||||
full_blocks=blocks,
|
||||
prev_block=None,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
assert len(manager.cached_block_hash_to_block) == 2
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Test that blocks that don't start from the beginning are cached correctly.
|
||||
blocks = [KVCacheBlock(block_id=2)]
|
||||
manager._cache_full_blocks(
|
||||
blocks += [KVCacheBlock(block_id=2)]
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blk_start_idx=2,
|
||||
full_blocks=blocks,
|
||||
prev_block=None,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=2,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
)
|
||||
assert len(manager.cached_block_hash_to_block) == 3
|
||||
assert len(block_pool.cached_block_hash_to_block) == 3
|
||||
assert blocks[0].block_hash is not None
|
||||
|
||||
|
||||
@@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
|
||||
# but it cannot be allocated due to insufficient free blocks (2).
|
||||
# In this case, the ref_cnt of the computed blocks should not be changed.
|
||||
assert manager.free_block_queue.num_free_blocks == 5
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||
req3 = make_request("3", common_token_ids * 3)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert computed_blocks == block_part1
|
||||
@@ -621,12 +630,12 @@ def test_reset_prefix_cache():
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
assert manager.cached_block_hash_to_block
|
||||
assert manager.block_pool.cached_block_hash_to_block
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
assert manager.reset_prefix_cache()
|
||||
assert not manager.cached_block_hash_to_block
|
||||
assert all([blk.block_hash is None for blk in manager.block_pool])
|
||||
assert not manager.block_pool.cached_block_hash_to_block
|
||||
assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
|
||||
|
||||
Reference in New Issue
Block a user