[Core] Introduce popleft_n and append_n in FreeKVCacheBlockQueue to further optimize block_pool (#21222)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
@@ -214,21 +214,18 @@ class BlockPool:
|
||||
raise ValueError(
|
||||
f"Cannot get {num_blocks} free blocks from the pool")
|
||||
|
||||
ret: list[KVCacheBlock] = []
|
||||
idx = 0
|
||||
while idx < num_blocks:
|
||||
# First allocate blocks.
|
||||
curr_block = self.free_block_queue.popleft()
|
||||
assert curr_block.ref_cnt == 0
|
||||
|
||||
# If the block is cached, evict it.
|
||||
if self.enable_caching:
|
||||
self._maybe_evict_cached_block(curr_block)
|
||||
|
||||
curr_block.incr_ref()
|
||||
ret.append(curr_block)
|
||||
idx += 1
|
||||
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
|
||||
|
||||
# In order to only iterate the list once, we duplicated code a bit
|
||||
if self.enable_caching:
|
||||
for block in ret:
|
||||
self._maybe_evict_cached_block(block)
|
||||
assert block.ref_cnt == 0
|
||||
block.ref_cnt += 1
|
||||
else:
|
||||
for block in ret:
|
||||
assert block.ref_cnt == 0
|
||||
block.ref_cnt += 1
|
||||
return ret
|
||||
|
||||
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
|
||||
@@ -289,11 +286,14 @@ class BlockPool:
|
||||
ordered_blocks: A list of blocks to free ordered by their eviction
|
||||
priority.
|
||||
"""
|
||||
for block in ordered_blocks:
|
||||
block.decr_ref()
|
||||
# null_block should not be added to the free list.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.append(block)
|
||||
# Materialize the iterable to allow multiple passes.
|
||||
blocks_list = list(ordered_blocks)
|
||||
for block in blocks_list:
|
||||
block.ref_cnt -= 1
|
||||
self.free_block_queue.append_n([
|
||||
block for block in blocks_list
|
||||
if block.ref_cnt == 0 and not block.is_null
|
||||
])
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
|
||||
@@ -154,6 +154,8 @@ class KVCacheBlock:
|
||||
# Whether the block is a null block that should never be cached.
|
||||
is_null: bool = False
|
||||
|
||||
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
|
||||
# avoid function calls.
|
||||
def incr_ref(self):
|
||||
self.ref_cnt += 1
|
||||
|
||||
@@ -273,6 +275,39 @@ class FreeKVCacheBlockQueue:
|
||||
self.num_free_blocks -= 1
|
||||
return first_block
|
||||
|
||||
def popleft_n(self, n: int) -> list[KVCacheBlock]:
|
||||
"""Pop the first n free blocks and reduce num_free_blocks by n.
|
||||
|
||||
Args:
|
||||
n: The number of blocks to pop.
|
||||
|
||||
Returns:
|
||||
A list of n free blocks.
|
||||
"""
|
||||
if n == 0:
|
||||
return []
|
||||
assert self.num_free_blocks >= n
|
||||
self.num_free_blocks -= n
|
||||
|
||||
curr_block = self.fake_free_list_head.next_free_block
|
||||
# Pop n blocks from the head of the list
|
||||
ret = []
|
||||
for _ in range(n):
|
||||
assert curr_block is not None
|
||||
ret.append(curr_block)
|
||||
last_block = curr_block
|
||||
curr_block = curr_block.next_free_block
|
||||
# Reset prev_free_block and next_free_block of all popped blocks
|
||||
last_block.prev_free_block = None
|
||||
last_block.next_free_block = None
|
||||
|
||||
if curr_block is not None:
|
||||
# The queue is not empty, connect the fake head to
|
||||
# the new first block.
|
||||
self.fake_free_list_head.next_free_block = curr_block
|
||||
curr_block.prev_free_block = self.fake_free_list_head
|
||||
return ret
|
||||
|
||||
def remove(self, block: KVCacheBlock) -> None:
|
||||
"""Remove a block in the free list and reduce num_free_blocks by 1.
|
||||
|
||||
@@ -315,6 +350,29 @@ class FreeKVCacheBlockQueue:
|
||||
|
||||
self.num_free_blocks += 1
|
||||
|
||||
def append_n(self, blocks: list[KVCacheBlock]) -> None:
|
||||
"""Put a list of blocks back into the free list
|
||||
|
||||
Args:
|
||||
blocks: The blocks to append.
|
||||
"""
|
||||
if len(blocks) == 0:
|
||||
return
|
||||
self.num_free_blocks += len(blocks)
|
||||
|
||||
last_block = self.fake_free_list_tail.prev_free_block
|
||||
assert last_block is not None, (
|
||||
"prev_free_block of fake_free_list_tail should always exist")
|
||||
# Add inter-connections between consecutive blocks
|
||||
for block in blocks:
|
||||
block.prev_free_block = last_block
|
||||
last_block.next_free_block = block
|
||||
last_block = block
|
||||
|
||||
# Connect the last block of <blocks> to the fake tail
|
||||
last_block.next_free_block = self.fake_free_list_tail
|
||||
self.fake_free_list_tail.prev_free_block = last_block
|
||||
|
||||
def get_all_free_blocks(self) -> list[KVCacheBlock]:
|
||||
"""Get all free blocks in the free list. Mainly used for testing.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user