[v1][Spec Decode] Make sliding window compatible with eagle prefix caching (#17398)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-05-01 02:25:53 +08:00
committed by GitHub
parent 42d9a2c4c7
commit 81ecf425f0
4 changed files with 96 additions and 23 deletions

View File

@@ -15,7 +15,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
KVCacheGroupSpec, SlidingWindowSpec)
def make_request(request_id,
@@ -863,11 +863,11 @@ def test_eagle_enabled_removes_last_block():
req_eagle = make_request("eagle_divisible", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 2 blocks:
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. last_block_hash is not None → Eagle pop is not SKIPPED
# 2. drop last matched block → 1 remaining block
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size # 32 tokens
assert num_tokens == 1 * block_size # 16 tokens
def test_eagle_with_partial_blocks():
@@ -894,3 +894,59 @@ def test_eagle_with_partial_blocks():
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size
def test_eagle_with_sliding_window():
"""Test Eagle behavior with sliding window."""
block_size = 16
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
use_mla=False,
)
manager = KVCacheManager(
KVCacheConfig(
num_blocks=10,
tensors={},
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
),
max_model_len=8192,
enable_caching=True,
use_eagle=True,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert num_tokens == 1 * block_size
# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block) is not None
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix.
assert len(computed_blocks) == 0
assert num_tokens == 0