[v1] Cleanup the BlockTable in InputBatch (#13977)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -15,13 +15,11 @@ class BlockTable:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_blocks_per_req: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
@@ -42,18 +40,19 @@ class BlockTable:
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
row_idx: int,
|
||||
start: int,
|
||||
block_ids: List[int],
|
||||
row_idx: int,
|
||||
) -> None:
|
||||
if not block_ids:
|
||||
return
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
self.num_blocks_per_row[row_idx] = start + num_blocks
|
||||
|
||||
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
|
||||
self.append_row(row_idx, 0, block_ids)
|
||||
def add_row(self, block_ids: List[int], row_idx: int) -> None:
|
||||
self.num_blocks_per_row[row_idx] = 0
|
||||
self.append_row(block_ids, row_idx)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
|
||||
Reference in New Issue
Block a user