2023-05-14 22:32:38 -07:00
|
|
|
"""Token blocks."""
|
2024-06-17 15:08:05 -07:00
|
|
|
import weakref
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from typing import Dict, List
|
2023-02-09 11:26:21 +00:00
|
|
|
|
2023-06-17 03:07:40 -07:00
|
|
|
from vllm.utils import Device
|
2023-02-09 11:26:21 +00:00
|
|
|
|
2023-05-14 22:32:38 -07:00
|
|
|
# Placeholder token ID that fills every slot of a newly created token block.
_BLANK_TOKEN_ID = -1
|
2023-02-09 11:26:21 +00:00
|
|
|
|
2024-03-02 03:50:01 -05:00
|
|
|
# Initial value assigned to `PhysicalTokenBlock.last_accessed` at construction
# (presumably meaning "never accessed yet" -- confirm against the callers).
DEFAULT_LAST_ACCESSED_TIME = -1
|
|
|
|
|
|
2024-06-17 15:08:05 -07:00
|
|
|
# Storage for one block's tokens: a plain list of integer token IDs.
TokensBlock = List[int]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BlockPool:
    """Recycler for the token-ID lists that back logical blocks.

    Incoming requests create many logical blocks and finished requests
    destroy many of them. Building and tearing down those blocks is
    costly, mainly because of the `token_ids` list of integers each one
    carries. Instead of discarding a list when its block goes away, the
    pool keeps it so that a later request can pick it up again.
    """

    def __init__(self) -> None:
        # Free lists keyed by block size (i.e. by the length of the stored
        # token lists).
        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

    def alloc_block(self, block_size: int) -> TokensBlock:
        """Return a token list of length `block_size`.

        A previously returned list of that size is reused when one is
        available (its old contents are left as they were); otherwise a
        new list filled with `_BLANK_TOKEN_ID` is created.
        """
        # `.get()` does not trigger the defaultdict factory, so a miss
        # leaves the pool without an empty entry for this size.
        recycled = self.pool.get(block_size)
        if recycled:
            return recycled.pop()
        return [_BLANK_TOKEN_ID] * block_size

    def del_block(self, block: TokensBlock) -> None:
        """Hand `block` back to the pool, filed under its own length."""
        size = len(block)
        self.pool[size].append(block)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level pool instance; LogicalTokenBlock draws its `token_ids` list
# from it and registers a finalizer that returns the list to it.
_BLOCK_POOL = BlockPool()
|
|
|
|
|
|
2023-02-09 11:26:21 +00:00
|
|
|
|
|
|
|
|
class LogicalTokenBlock:
    """A fixed-capacity block holding a contiguous run of token IDs.

    Tokens are written from left to right. Logical blocks are used to
    represent the states of the corresponding physical blocks in the KV
    cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        """Create an empty block.

        Args:
            block_number: Number identifying this block.
            block_size: Maximum number of tokens the block can hold.
        """
        self.block_number = block_number
        self.block_size = block_size

        token_storage = _BLOCK_POOL.alloc_block(block_size)
        self.token_ids = token_storage
        # Give the storage back to the pool once this object goes away.
        # NOTE: `__del__` is deliberately avoided: it gives no guarantee
        # about finalization order, so `self.token_ids` could already be
        # gone by the time `self` is finalized and the list could no
        # longer be returned to the pool.
        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
                                           token_storage)

        # Number of slots of `token_ids` that currently hold real tokens.
        self.num_tokens = 0

    def is_empty(self) -> bool:
        """Return True when no token has been stored yet."""
        return not self.num_tokens

    def get_num_empty_slots(self) -> int:
        """Return how many more tokens fit into the block."""
        used = self.num_tokens
        return self.block_size - used

    def is_full(self) -> bool:
        """Return True when every slot of the block is occupied."""
        return self.block_size == self.num_tokens

    def append_tokens(self, token_ids: List[int]) -> None:
        """Write `token_ids` into the next free slots.

        The caller must not pass more tokens than there are free slots;
        this is checked with an assertion.
        """
        num_new = len(token_ids)
        assert num_new <= self.get_num_empty_slots()
        start = self.num_tokens
        end = start + num_new
        self.token_ids[start:end] = token_ids
        self.num_tokens = end

    def get_token_ids(self) -> List[int]:
        """Return a copy of the tokens stored so far (unused slots omitted)."""
        filled = self.num_tokens
        return self.token_ids[:filled]

    def get_last_token_id(self) -> int:
        """Return the most recently stored token; the block must be non-empty."""
        assert self.num_tokens > 0
        last_idx = self.num_tokens - 1
        return self.token_ids[last_idx]
|
|
|
|
|
|
2023-02-09 11:26:21 +00:00
|
|
|
|
|
|
|
|
class PhysicalTokenBlock:
    """State record for one block of the KV cache."""

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
        block_hash: int,
        num_hashed_tokens: int,
    ) -> None:
        """Store the block's identity and reset its bookkeeping fields.

        Args:
            device: Device the block belongs to.
            block_number: Number identifying the block.
            block_size: Size of the block.
            block_hash: Hash value associated with the block.
            num_hashed_tokens: Number of hashed tokens recorded for the
                block.
        """
        # Identity / configuration supplied by the caller.
        self.device = device
        self.block_number = block_number
        self.block_size = block_size
        self.block_hash = block_hash
        self.num_hashed_tokens = num_hashed_tokens

        # Bookkeeping, always starting from the same initial values.
        self.ref_count = 0
        self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
        self.computed = False

    def __repr__(self) -> str:
        """Return a debug string (block_size and block_hash are not shown)."""
        shown_fields = (
            f'device={self.device}',
            f'block_number={self.block_number}',
            f'num_hashed_tokens={self.num_hashed_tokens}',
            f'ref_count={self.ref_count}',
            f'last_accessed={self.last_accessed}',
            f'computed={self.computed}',
        )
        return 'PhysicalTokenBlock(' + ', '.join(shown_fields) + ')'
|
2024-01-17 16:32:10 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mapping: logical block number -> physical block.
|
|
|
|
|
# Type alias: a block table is a list of PhysicalTokenBlock objects.
BlockTable = List[PhysicalTokenBlock]
|