[Core] Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue (#21005)

Signed-off-by: Jialin Ouyang <jialino@meta.com>
This commit is contained in:
JialinOuyang-Meta
2025-07-18 12:34:40 -07:00
committed by GitHub
parent b2eb2b5ad7
commit 0f199f197b
4 changed files with 209 additions and 57 deletions

View File

@@ -212,27 +212,65 @@ class FreeKVCacheBlockQueue:
def __init__(self, blocks: list[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)
# Initialize the doubly linked list of free blocks.
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
# Initialize the double links between consecutive blocks.
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
if i < self.num_free_blocks - 1:
blocks[i].next_free_block = blocks[i + 1]
# Create a fake head and a tail block for the doubly linked list to
# reduce branching in the code
#
# The implementation guarantees that the fake head and tail
# are NEVER popped, so we can safely assume that each real block
# in the queue has both a prev and a next block.
self.fake_free_list_head = KVCacheBlock(block_id=-1)
self.fake_free_list_tail = KVCacheBlock(block_id=-1)
if self.num_free_blocks > 0:
# Connect fake_head and fake_tail to the first and last block
# respectively.
self.fake_free_list_head.next_free_block = blocks[0]
blocks[0].prev_free_block = self.fake_free_list_head
self.fake_free_list_tail.prev_free_block = blocks[-1]
blocks[-1].next_free_block = self.fake_free_list_tail
else:
# For empty list, simply connect the fake head and tail.
self.fake_free_list_head.next_free_block = self.fake_free_list_tail
self.fake_free_list_tail.prev_free_block = self.fake_free_list_head
def popleft(self) -> KVCacheBlock:
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
The first free block.
"""
if not self.free_list_head:
if (self.fake_free_list_head.next_free_block
is self.fake_free_list_tail
or self.fake_free_list_head.next_free_block is None):
assert self.num_free_blocks == 0, (
f"num_free_blocks ({self.num_free_blocks}) is out of sync "
"with the free list.")
raise ValueError("No free blocks available")
block = self.free_list_head
self.remove(block)
return block
first_block: KVCacheBlock = self.fake_free_list_head.next_free_block
if first_block.next_free_block is None:
# This should not happen if the block is from the free list.
# It indicates a bug in the caller's logic.
raise RuntimeError("Invalid block found in popleft() "
"which doesn't have a valid next_free_block")
# Connect fake_head and the next block of first_block (i.e. second block
# or fake tail).
self.fake_free_list_head.next_free_block = first_block.next_free_block
first_block.next_free_block.prev_free_block = self.fake_free_list_head
# Remove the block from the linked list.
first_block.prev_free_block = first_block.next_free_block = None
self.num_free_blocks -= 1
return first_block
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
@@ -240,19 +278,15 @@ class FreeKVCacheBlockQueue:
Args:
block: The block to remove.
"""
if block.prev_free_block is not None:
# Link the previous block to the next block.
block.prev_free_block.next_free_block = block.next_free_block
if block.next_free_block is not None:
# Link the next block to the previous block.
block.next_free_block.prev_free_block = block.prev_free_block
if block.prev_free_block is None or block.next_free_block is None:
# This should not happen if the block is from the free list.
# It indicates a bug in the caller's logic.
raise RuntimeError(f"remove() called on an invalid block: {block}")
if block == self.free_list_head:
# Update the head if the block is the head.
self.free_list_head = block.next_free_block
if block == self.free_list_tail:
# Update the tail if the block is the tail.
self.free_list_tail = block.prev_free_block
# Link the previous block to the next block.
block.prev_free_block.next_free_block = block.next_free_block
# Link the next block to the previous block.
block.next_free_block.prev_free_block = block.prev_free_block
# Remove the block from the linked list.
block.prev_free_block = block.next_free_block = None
@@ -265,17 +299,19 @@ class FreeKVCacheBlockQueue:
Args:
block: The block to append.
"""
if self.free_list_tail is not None:
# Link the last block to the new block.
self.free_list_tail.next_free_block = block
block.prev_free_block = self.free_list_tail
self.free_list_tail = block
else:
# The free list is empty.
assert self.free_list_head is None
self.free_list_head = self.free_list_tail = block
if self.fake_free_list_tail.prev_free_block is None:
raise RuntimeError(
"prev_free_block of fake_free_list_tail should always exist")
last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block
# Connect the new block after the last block.
last_block.next_free_block = block
block.prev_free_block = last_block
# Connect the fake tail after the new block.
block.next_free_block = self.fake_free_list_tail
self.fake_free_list_tail.prev_free_block = block
block.next_free_block = None
self.num_free_blocks += 1
def get_all_free_blocks(self) -> list[KVCacheBlock]:
@@ -285,8 +321,14 @@ class FreeKVCacheBlockQueue:
A list of free blocks.
"""
ret = []
curr_block = self.free_list_head
while curr_block is not None:
if self.fake_free_list_head.next_free_block is None:
raise RuntimeError(
"next_free_block of fake_free_list_head should always exist")
# Start from the first block
curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block
# As long as next_free_block is available, we haven't reached
# the fake tail yet.
while curr_block.next_free_block is not None:
ret.append(curr_block)
curr_block = curr_block.next_free_block
return ret