[Core] Introduce popleft_n and append_n in FreeKVCacheBlockQueue to further optimize block_pool (#21222)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Jialin Ouyang
2025-07-22 06:17:47 -07:00
committed by GitHub
parent 10904e6d75
commit ed25054577
3 changed files with 182 additions and 19 deletions

View File

@@ -214,21 +214,18 @@ class BlockPool:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: list[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
# In order to only iterate the list once, we duplicated code a bit
if self.enable_caching:
for block in ret:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
@@ -289,11 +286,14 @@ class BlockPool:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
for block in ordered_blocks:
block.decr_ref()
# null_block should not be added to the free list.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.append(block)
# Materialize the iterable to allow multiple passes.
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
self.free_block_queue.append_n([
block for block in blocks_list
if block.ref_cnt == 0 and not block.is_null
])
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF

View File

@@ -154,6 +154,8 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached.
is_null: bool = False
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
# avoid function calls.
def incr_ref(self):
    """Increase this block's reference count (``ref_cnt``) by one."""
    self.ref_cnt = self.ref_cnt + 1
@@ -273,6 +275,39 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks -= 1
return first_block
def popleft_n(self, n: int) -> list[KVCacheBlock]:
    """Pop the first n free blocks and reduce num_free_blocks by n.

    Blocks are taken starting from the block that follows
    ``fake_free_list_head`` and are returned in queue order. Both free-list
    links of every popped block are cleared.

    Args:
        n: The number of blocks to pop. Must be non-negative and must not
            exceed ``num_free_blocks``.

    Returns:
        A list of n free blocks.

    Raises:
        ValueError: If n is negative.
    """
    if n == 0:
        return []
    if n < 0:
        # A negative n would pass the availability assertion and skip the
        # loop below, yet the subtraction would still *increase*
        # num_free_blocks, silently corrupting the free count.
        raise ValueError(f"Cannot pop a negative number of blocks: {n}")
    assert self.num_free_blocks >= n
    self.num_free_blocks -= n

    curr_block = self.fake_free_list_head.next_free_block
    # Pop n blocks from the head of the list
    ret = []
    for _ in range(n):
        assert curr_block is not None
        ret.append(curr_block)
        last_block = curr_block
        # Advance before clearing the links, otherwise the successor
        # would be lost.
        curr_block = curr_block.next_free_block
        # Reset prev_free_block and next_free_block of all popped blocks
        last_block.prev_free_block = None
        last_block.next_free_block = None

    if curr_block is not None:
        # The queue is not empty, connect the fake head to
        # the new first block.
        self.fake_free_list_head.next_free_block = curr_block
        curr_block.prev_free_block = self.fake_free_list_head
    return ret
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
@@ -315,6 +350,29 @@ class FreeKVCacheBlockQueue:
self.num_free_blocks += 1
def append_n(self, blocks: list[KVCacheBlock]) -> None:
    """Put a list of blocks back into the free list.

    The blocks are linked, in the order given, between the block that
    currently precedes ``fake_free_list_tail`` and the fake tail itself,
    and ``num_free_blocks`` grows by ``len(blocks)``.

    Args:
        blocks: The blocks to append. An empty list is a no-op.
    """
    if not blocks:
        return

    last_block = self.fake_free_list_tail.prev_free_block
    assert last_block is not None, (
        "prev_free_block of fake_free_list_tail should always exist")
    # Update the counter only after the tail invariant has been verified,
    # so a failed assertion cannot leave num_free_blocks out of sync with
    # the list contents.
    self.num_free_blocks += len(blocks)
    # Add inter-connections between consecutive blocks
    for block in blocks:
        block.prev_free_block = last_block
        last_block.next_free_block = block
        last_block = block

    # Connect the last block of <blocks> to the fake tail
    last_block.next_free_block = self.fake_free_list_tail
    self.fake_free_list_tail.prev_free_block = last_block
def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.