[v1] Hybrid Memory Allocator (#17996)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-06-06 11:47:09 +08:00
committed by GitHub
parent 3465b87ef8
commit f8a1a2d108
21 changed files with 1605 additions and 440 deletions

View File

@@ -4,7 +4,8 @@
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
@@ -12,9 +13,8 @@ from vllm.v1.kv_cache_interface import SlidingWindowSpec
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False,
num_kv_cache_groups=1,
caching_hash_fn=lambda x: x)
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
def test_sliding_window_possible_cached_prefix():
@@ -42,13 +42,18 @@ def test_sliding_window_possible_cached_prefix():
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[block_hash] = {
i: block_pool.blocks[i + 10]
}
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
computed_blocks = manager.find_longest_cache_hit(
block_hash_list,
len(block_hash_list) * block_size)
block_hashes=block_hash_list,
max_length=len(block_hash_list) * block_size,
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False)[0]
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block
@@ -95,13 +100,13 @@ def test_sliding_window_remove_skipped_blocks():
null_block_id = block_pool.null_block.block_id
def id_to_block_table(ids):
def id_to_block_table(ids) -> list[KVCacheBlock]:
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]
def assert_block_id(block_table, ids):
def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block