[Attention] Refactor attention metadata builder interface (#20466)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -14,12 +14,14 @@ class BlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
max_num_batched_tokens: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = block_size
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
@@ -79,10 +81,31 @@ class BlockTable:
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
                         positions: np.ndarray) -> None:
    """Write per-token slot indices into ``self.slot_mapping_np``.

    For every token, the slot is ``block_number * block_size +
    position % block_size``, where ``block_number`` is looked up in the
    flattened CPU block table at
    ``req_index * max_num_blocks_per_req + position // block_size``.

    Args:
        req_indices: For each token, the block-table row it belongs to.
        positions: For each token, its position within its request.

    Returns:
        None. The first ``req_indices.shape[0]`` entries of
        ``self.slot_mapping_np`` are overwritten in place.
    """
    num_tokens = req_indices.shape[0]
    block_size = self.block_size

    # Example with block size 2, K = max_num_blocks_per_req:
    #   positions [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    #   -> flat indices
    #      [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
    # NOTE(woosuk): We can't simply use `token_indices // block_size`
    # here because M (max_model_len) is not necessarily divisible by
    # block_size.
    row_starts = req_indices * self.max_num_blocks_per_req
    flat_indices = row_starts + positions // block_size

    # Look up the block number of every token in the flattened table.
    flat_table = self.get_cpu_tensor().flatten()
    block_ids = flat_table[flat_indices].numpy()

    # Offset of each token inside its block.
    in_block_offsets = positions % block_size

    # Write straight into the preallocated buffer instead of allocating.
    np.add(block_ids * block_size,
           in_block_offsets,
           out=self.slot_mapping_np[:num_tokens])
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
    """Copy the first ``num_reqs`` rows of the CPU block table.

    The rows are copied from ``self.block_table_cpu`` into the matching
    rows of ``self.block_table`` (presumably the device-side tensor —
    confirm against ``__init__``).

    Args:
        num_reqs: Number of leading rows to copy.
    """
    rows = slice(None, num_reqs)
    destination = self.block_table[rows]
    source = self.block_table_cpu[rows]
    # non_blocking=True requests an asynchronous copy where supported.
    destination.copy_(source, non_blocking=True)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
    """Copy the first ``num_tokens`` slot-mapping entries from the CPU copy.

    The entries are copied from ``self.slot_mapping_cpu`` into the
    matching prefix of ``self.slot_mapping`` (presumably the device-side
    tensor — confirm against ``__init__``).

    Args:
        num_tokens: Number of leading entries to copy.
    """
    prefix = slice(None, num_tokens)
    destination = self.slot_mapping[prefix]
    source = self.slot_mapping_cpu[prefix]
    # non_blocking=True requests an asynchronous copy where supported.
    destination.copy_(source, non_blocking=True)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
@@ -107,7 +130,8 @@ class MultiGroupBlockTable:
|
||||
max_num_batched_tokens: int, pin_memory: bool,
|
||||
device: torch.device, block_sizes: list[int]) -> None:
|
||||
self.block_tables = [
|
||||
BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
|
||||
BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
|
||||
block_size),
|
||||
max_num_batched_tokens, pin_memory, device)
|
||||
for block_size in block_sizes
|
||||
]
|
||||
@@ -129,9 +153,18 @@ class MultiGroupBlockTable:
|
||||
for block_table in self.block_tables:
|
||||
block_table.swap_row(src, tgt)
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
                         positions: np.ndarray) -> None:
    """Compute the slot mapping on every block table in the group.

    Forwards both arguments unchanged to ``compute_slot_mapping`` of each
    entry in ``self.block_tables``.

    Args:
        req_indices: Per-token request indices, passed through as-is.
        positions: Per-token positions, passed through as-is.
    """
    # The stale `block_table.commit(num_reqs)` call is intentionally gone:
    # `num_reqs` is not a parameter of this method (it would raise
    # NameError), and committing is handled by `commit_block_table` /
    # `commit_slot_mapping`.
    for block_table in self.block_tables:
        block_table.compute_slot_mapping(req_indices, positions)
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
    """Commit the first ``num_reqs`` rows of every block table in the group.

    Args:
        num_reqs: Row count forwarded to each table's
            ``commit_block_table``.
    """
    # Tables are visited in list order; each one receives the same count.
    for table in self.block_tables:
        table.commit_block_table(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
    """Commit the first ``num_tokens`` slot entries of every block table.

    Args:
        num_tokens: Entry count forwarded to each table's
            ``commit_slot_mapping``.
    """
    # Tables are visited in list order; each one receives the same count.
    for table in self.block_tables:
        table.commit_slot_mapping(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
for block_table in self.block_tables:
|
||||
|
||||
Reference in New Issue
Block a user