[v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -4,6 +4,8 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -96,3 +98,48 @@ class BlockTable:
|
||||
def get_numpy_array(self) -> np.ndarray:
    """Return the block table as a numpy array.

    The object stored in ``self.block_table_np`` is handed back as-is;
    no copy is made here.
    """
    block_table_np = self.block_table_np
    return block_table_np
|
||||
|
||||
|
||||
class MultiGroupBlockTable:
    """The BlockTables for each KV cache group.

    Holds one ``BlockTable`` per entry of ``kv_cache_config.kv_cache_groups``
    (same order) in ``self.block_tables`` and forwards every row operation
    to each of them, so the per-group tables are always updated together.
    """

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, kv_cache_config: KVCacheConfig) -> None:
        """Create one BlockTable per KV cache group.

        Args:
            max_num_reqs: Forwarded unchanged to every ``BlockTable``.
            max_model_len: Divided by each group's block size (via ``cdiv``)
                to get that group's maximum number of blocks per request.
            max_num_batched_tokens: Forwarded unchanged to every
                ``BlockTable``.
            pin_memory: Forwarded unchanged to every ``BlockTable``.
            device: Forwarded unchanged to every ``BlockTable``.
            kv_cache_config: Supplies ``kv_cache_groups``; one table is
                built per group, in iteration order.
        """
        # The block size is read per group, so each table gets its own
        # max-blocks-per-request bound rather than a shared one.
        self.block_tables = [
            BlockTable(max_num_reqs,
                       cdiv(max_model_len, group.kv_cache_spec.block_size),
                       max_num_batched_tokens, pin_memory, device)
            for group in kv_cache_config.kv_cache_groups
        ]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``append_row(block_ids[i], row_idx)`` on the i-th table.

        ``block_ids`` is indexed by group position, so it needs at least one
        entry per table; a shorter sequence raises ``IndexError``.
        """
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``add_row(block_ids[i], row_idx)`` on the i-th table.

        ``block_ids`` is indexed by group position, so it needs at least one
        entry per table; a shorter sequence raises ``IndexError``.
        """
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Call ``move_row(src, tgt)`` on every group's table."""
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Call ``swap_row(src, tgt)`` on every group's table."""
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Call ``commit(num_reqs)`` on every group's table."""
        for block_table in self.block_tables:
            block_table.commit(num_reqs)

    def clear(self) -> None:
        """Call ``clear()`` on every group's table."""
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
|
||||
|
||||
Reference in New Issue
Block a user