[v1] Cleanup the BlockTable in InputBatch (#13977)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-03-01 03:03:16 +08:00
committed by GitHub
parent c3b6559a10
commit e7bd944e08
5 changed files with 25 additions and 17 deletions

View File

@@ -15,13 +15,11 @@ class BlockTable:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device
@@ -42,18 +40,19 @@ class BlockTable:
def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
row_idx: int,
) -> None:
if not block_ids:
return
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = start + num_blocks
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
self.append_row(row_idx, 0, block_ids)
def add_row(self, block_ids: List[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]