[Bugfix][fast] Fix the get_num_blocks_touched logic (#6849)
This commit is contained in:
@@ -100,3 +100,45 @@ class TestNaiveBlockAllocator:
|
||||
for i, block in enumerate(blocks):
|
||||
assert allocator.get_num_free_blocks() == i
|
||||
allocator.free(block)
|
||||
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [4])
|
||||
@pytest.mark.parametrize("block_size", [8])
|
||||
def test_naive_block_get_num_blocks_touched(num_blocks, block_size):
|
||||
""" Verify the allocator can correctly return the number of
|
||||
blocks touched, with different lookahead slots.
|
||||
"""
|
||||
allocator_src = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
allocator_dst = NaiveBlockAllocator(create_block=NaiveBlock,
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
|
||||
# Create a chain of cacheable blocks in the dst
|
||||
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
"immutable",
|
||||
allocator_src,
|
||||
prev_block=None,
|
||||
token_ids=list(range(block_size)))
|
||||
src_blocks = [allocate_block() for _ in range(num_blocks - 1)]
|
||||
|
||||
# All blocks are cached
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks) == num_blocks - 1
|
||||
|
||||
# Insert one non-full block in the src
|
||||
allocate_non_full_block = \
|
||||
TestNaiveBlockAllocator.create_allocate_lambda(
|
||||
"mutable", allocator_src,
|
||||
prev_block=src_blocks[-1],token_ids=[]
|
||||
)
|
||||
src_blocks.append(allocate_non_full_block())
|
||||
src_blocks[-1].append_token_ids([0])
|
||||
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks, num_lookahead_slots=1) == num_blocks
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks, num_lookahead_slots=block_size - 1) == num_blocks
|
||||
assert allocator_dst.get_num_blocks_touched(
|
||||
src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1)
|
||||
|
||||
Reference in New Issue
Block a user