[V1] Implement sliding window attention in kv_cache_manager (#14097)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-04-01 15:33:17 +08:00
committed by GitHub
parent c7e63aa4d8
commit 3a5f0afcd2
15 changed files with 662 additions and 158 deletions

View File

@@ -4,6 +4,7 @@
from typing import Optional
import pytest
import torch
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
@@ -12,6 +13,8 @@ from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
def make_request(request_id,
@@ -39,13 +42,23 @@ def make_request(request_id,
)
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
    """Build a minimal KVCacheConfig for tests: one KV-cache group holding a
    single full-attention layer named 'layer'.

    Args:
        block_size: Number of tokens stored per KV-cache block.
        num_blocks: Total number of blocks in the cache.

    Returns:
        A KVCacheConfig with no backing tensors and one
        FullAttentionSpec-based group.
    """
    # Single-layer spec: 1 KV head, head size 1, fp32, no MLA — the smallest
    # configuration the manager accepts.
    attn_spec = FullAttentionSpec(block_size, 1, 1, torch.float32, False)
    layer_group = KVCacheGroupSpec(['layer'], attn_spec)
    return KVCacheConfig(
        num_blocks=num_blocks,
        tensors={},
        kv_cache_groups=[layer_group],
    )
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
def test_prefill(hash_algo):
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
@@ -67,12 +80,12 @@ def test_prefill(hash_algo):
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
@@ -80,7 +93,7 @@ def test_prefill(hash_algo):
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
for block_id in (4, 5):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@@ -90,11 +103,11 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
assert [b.block_id for b in blocks] == [6, 7]
for block in computed_blocks:
assert block.ref_cnt == 2
@@ -107,14 +120,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
@@ -122,11 +135,11 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
assert [b.block_id for b in blocks] == [8, 9]
# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@@ -148,7 +161,7 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -162,10 +175,8 @@ def test_prefill_plp():
3. Schedule plp request; no hit should occur; validate blocks
'''
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -186,13 +197,13 @@ def test_prefill_plp():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
req0_block_hashes = [b.block_hash for b in blocks]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
@@ -200,7 +211,7 @@ def test_prefill_plp():
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
for block_id in (4, 5):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@@ -211,11 +222,11 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
assert [b.block_id for b in blocks] == [6, 7]
for block in computed_blocks:
assert block.ref_cnt == 2
@@ -228,14 +239,14 @@ def test_prefill_plp():
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
@@ -251,7 +262,7 @@ def test_prefill_plp():
block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [0, 1, 2, 3, 4]
assert block_ids != [1, 2, 3, 4, 5]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@@ -263,10 +274,8 @@ def test_prefill_plp():
def test_decode():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -282,7 +291,7 @@ def test_decode():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@@ -316,10 +325,8 @@ def test_decode():
def test_evict():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -350,15 +357,15 @@ def test_evict():
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
assert [b.block_id for b in computed_blocks] == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert [b.block_id for b in blocks] == [7, 6]
assert manager.block_pool.free_block_queue.num_free_blocks == 6
@@ -369,10 +376,8 @@ def test_hash_block_correct_reuse():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=1,
make_kv_cache_config(16, 2),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@@ -408,10 +413,8 @@ def test_computed_blocks_not_evicted():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=2,
make_kv_cache_config(block_size, 3),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@@ -424,7 +427,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0
assert blocks[0].block_id == 1
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
@@ -433,7 +436,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert blocks[0].block_id == 2
# Free the blocks.
manager.free(req0)
@@ -444,13 +447,13 @@ def test_computed_blocks_not_evicted():
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0
assert computed_blocks[0].block_id == 1
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert blocks[0].block_id == 2
def test_basic_prefix_caching_disabled():
@@ -459,10 +462,8 @@ def test_basic_prefix_caching_disabled():
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=4,
make_kv_cache_config(block_size, 5),
max_model_len=8192,
sliding_window=None,
enable_caching=False,
num_preallocate_tokens=0,
)
@@ -502,10 +503,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
This tests that the preallocated blocks are correctly added.
"""
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
make_kv_cache_config(block_size, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
@@ -586,10 +585,8 @@ def test_mm_prefix_caching():
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@@ -629,7 +626,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@@ -667,10 +664,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
make_kv_cache_config(block_size, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@@ -723,10 +718,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def test_reset_prefix_cache():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@@ -736,7 +729,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@@ -745,7 +738,7 @@ def test_reset_prefix_cache():
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [4]
assert [b.block_id for b in blocks] == [5]
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()