[v1] Support multiple KV cache groups in GPU model runner (#17945)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-05-15 09:54:54 +08:00
committed by GitHub
parent f25e0d1125
commit e60f550b38
16 changed files with 482 additions and 215 deletions

View File

@@ -4,6 +4,8 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__)
@@ -96,3 +98,48 @@ class BlockTable:
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, kv_cache_config: KVCacheConfig) -> None:
max_num_blocks_per_req = [
cdiv(max_model_len, g.kv_cache_spec.block_size)
for g in kv_cache_config.kv_cache_groups
]
self.block_tables = [
BlockTable(max_num_reqs, max_num_blocks_per_req[i],
max_num_batched_tokens, pin_memory, device)
for i in range(len(kv_cache_config.kv_cache_groups))
]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def commit(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit(num_reqs)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]