[Bugfix][fast] Fix the get_num_blocks_touched logic (#6849)

This commit is contained in:
Zach Zheng
2024-08-08 10:43:30 -07:00
committed by GitHub
parent 21b9c49aa3
commit 782e53ab59
6 changed files with 172 additions and 10 deletions

View File

@@ -100,3 +100,45 @@ class TestNaiveBlockAllocator:
for i, block in enumerate(blocks):
assert allocator.get_num_free_blocks() == i
allocator.free(block)
@staticmethod
@pytest.mark.parametrize("num_blocks", [4])
@pytest.mark.parametrize("block_size", [8])
def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
""" Verify the allocator can correctly return the number of
blocks touched, with different lookahead slots.
"""
allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
num_blocks=num_blocks,
block_size=block_size)
allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock,
num_blocks=num_blocks,
block_size=block_size)
# Create a chain of cacheable blocks in the dst
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
"immutable",
allocator_src,
prev_block=None,
token_ids=list(range(block_size)))
src_blocks = [allocate_block() for _ in range(num_blocks - 1)]
# All blocks are cached
assert allocator_dst.get_num_blocks_touched(
src_blocks) == num_blocks - 1
# Insert one non-full block in the src
allocate_non_full_block = \
TestNaiveBlockAllocator.create_allocate_lambda(
"mutable", allocator_src,
prev_block=src_blocks[-1],token_ids=[]
)
src_blocks.append(allocate_non_full_block())
src_blocks[-1].append_token_ids([0])
assert allocator_dst.get_num_blocks_touched(
src_blocks, num_lookahead_slots=1) == num_blocks
assert allocator_dst.get_num_blocks_touched(
src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
assert allocator_dst.get_num_blocks_touched(
src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)