[Hybrid Allocator] Support KV cache groups with different block_size (#29143)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Yifan Qiao
2025-11-25 07:30:57 -08:00
committed by GitHub
parent e502098643
commit 48ddb02b79
11 changed files with 472 additions and 87 deletions

View File

@@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
BlockHashWithGroupId,
ExternalBlockHash,
FreeKVCacheBlockQueue,
@@ -133,6 +135,10 @@ class BlockPool:
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
hash_block_size: The block size of which the block hashes are computed.
The actual block size usually equals hash_block_size, but in cases
where different KV cache groups have different block sizes, the
actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events.
"""
@@ -140,11 +146,13 @@ class BlockPool:
self,
num_gpu_blocks: int,
enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
self.hash_block_size = hash_block_size
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@@ -223,8 +231,20 @@ class BlockPool:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:]
if block_size == self.hash_block_size:
# Common case.
block_hashes: BlockHashList = request.block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when
# different KV cache groups have different block sizes.
assert block_size % self.hash_block_size == 0
# Recalculate block_hashes at the granularity of block_size, using
# the original block_hashes (at the granularity of hash_block_size).
block_hashes = BlockHashListWithBlockSize(
request.block_hashes, self.hash_block_size, block_size
)
new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None
)

View File

@@ -2,15 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
KVCacheBlock,
)
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
FullAttentionManager,
get_manager_for_kv_cache_spec,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.request import Request
@@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.block_pool = BlockPool(
kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
kv_cache_config.num_blocks,
enable_caching,
hash_block_size,
enable_kv_cache_events,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
@@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
self.num_single_type_manager = len(self.single_type_managers)
@@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self.block_size *= dcp_world_size
if pcp_world_size > 1:
self.block_size *= pcp_world_size
# For models using only Mamba, block_size is set to max_model_len when
# prefix caching is disabled, and hash_block_size validation is skipped.
assert not enable_caching or (hash_block_size == self.block_size), (
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
)
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
@@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
@@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
super().__init__(
kv_cache_config,
@@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
# hash_block_size: the block size used to compute block hashes.
# The actual block size usually equals hash_block_size, but in cases where
# different KV cache groups have different block sizes, the actual block size
# can be a multiple of hash_block_size.
self.hash_block_size = hash_block_size
assert all(
g.kv_cache_spec.block_size % hash_block_size == 0
for g in kv_cache_config.kv_cache_groups
), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
@@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
if self.enable_caching:
# this requirement is only needed for the prefix caching logic
divisible = self.other_block_size % self.full_attention_block_size
assert divisible == 0, (
"KVCacheCoordinator assumes the block_size of full "
"attention layers is divisible by other layers now."
)
# The LCM of the block sizes of full attention and other attention.
# The cache hit length must be a multiple of the LCM of the block sizes
# to make sure the cache hit length is a multiple of the block size of
# each attention type. Requiring this because we don't support partial
# block cache hit yet.
self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
@@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size:
# Common case.
full_attention_block_hashes: BlockHashList = block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when different
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes,
block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
)
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes,
block_hashes=other_block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
@@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
max_model_len,
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
dcp_world_size,
pcp_world_size,
hash_block_size,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
dcp_world_size,
pcp_world_size,
hash_block_size,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
dcp_world_size,
pcp_world_size,
hash_block_size,
)

View File

@@ -95,6 +95,7 @@ class KVCacheManager:
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
hash_block_size: int,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
@@ -107,28 +108,11 @@ class KVCacheManager:
self.enable_caching = enable_caching
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
# FIXME: make prefix cache stats conditional on log_stats. We still need
# this comment because when the log stats is enabled there are still
# potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.block_size: int | None = None
if self.enable_caching:
assert (
len(
set(
g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups
)
)
== 1
), "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0
].kv_cache_spec.block_size
if dcp_world_size * pcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
self.block_size *= dcp_world_size * pcp_world_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
@@ -137,6 +121,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool

View File

@@ -5,9 +5,9 @@
import copy
import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import Any, NewType, TypeAlias
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from typing import Any, NewType, TypeAlias, overload
from vllm import envs
from vllm.config import VllmConfig
@@ -825,11 +825,11 @@ def get_num_blocks(
return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1
return page_sizes.pop()
@@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
return len(page_sizes) == 1
def unify_kv_cache_spec_page_size(
kv_cache_spec: dict[str, KVCacheSpec],
) -> dict[str, KVCacheSpec]:
"""
Unify the page size of the given KVCacheSpec. If the page size of all layers
are the same, return the original KVCacheSpec. If not same, unify the page
size by increasing the block size of layers with smaller page size. Raise
NotImplementedError if failed to unify the page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The updated KVCacheSpec with the same page_size_bytes.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
if len(page_sizes) <= 1:
# All layers have the same page size, no need to unify.
return kv_cache_spec
max_page_size = max(page_sizes)
new_kv_cache_spec = {}
for layer_name, layer_spec in kv_cache_spec.items():
if layer_spec.page_size_bytes == max_page_size:
new_kv_cache_spec[layer_name] = layer_spec
else:
layer_page_size = layer_spec.page_size_bytes
if max_page_size % layer_page_size != 0:
raise NotImplementedError(
"The page size of the layer is not divisible by the "
"maximum page size. Cannot unify by adjusting block_size."
)
ratio = max_page_size // layer_page_size
new_block_size = layer_spec.block_size * ratio
new_spec = replace(layer_spec, block_size=new_block_size)
assert new_spec.page_size_bytes == max_page_size
new_kv_cache_spec[layer_name] = new_spec
return new_kv_cache_spec
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
# kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec
@@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
def get_kv_cache_config_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
@@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
Args:
vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups
kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes
Returns:
The generated KVCacheConfig
@@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs)
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size
@@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle
# attention free models.
return []
elif is_kv_cache_spec_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
@@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
raise NotImplementedError
# As KVCacheManager can only allocate memory of one size, we need to unify
# the page size of the layers. For cases cannot be unified, this function
# will raise an error.
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config(
@@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append(
get_kv_cache_config_from_groups(
vllm_config,
kv_cache_groups_one_worker,
kv_cache_spec_one_worker,
available_memory_one_worker,
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
)
)
@@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
_report_kv_cache_config(vllm_config, kv_cache_config)
return kv_cache_configs
class BlockHashListWithBlockSize:
"""
Convert block-hash granularity from `hash_block_size` to `target_block_size`.
Used when KV cache groups have different block sizes: `hash_block_size`
is the size used to compute the original `block_hashes`; `target_block_size`
is the group's actual block size.
Currently, only scaling up by an integer factor is supported (i.e.,
`target_block_size` is a multiple of `hash_block_size`). Conversion is
performed lazily on access for efficiency, by concatenating consecutive
hashes at `hash_block_size` to form each hash at `target_block_size`.
Example (`hash_block_size` = 16, `target_block_size` = 32):
concatenating two 16-size hashes yields one 32-size hash:
Block hashes with block_size 16:
| Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
|-------------|------|-------|-------|-------|
| Hash | A | B | C | D |
Block hashes with block_size 32:
| Token Range | 0-31 | 32-63 |
|-------------|------|-------|
| Hash | AB | CD |
Args:
block_hashes: Block hashes to convert, computed at `hash_block_size`.
hash_block_size: Block size at which `block_hashes` were computed.
target_block_size: Desired block size; must be a multiple of `hash_block_size`.
"""
def __init__(
self,
block_hashes: list[BlockHash],
hash_block_size: int,
target_block_size: int,
):
self.block_hashes = block_hashes
assert target_block_size % hash_block_size == 0
self.scale_factor = target_block_size // hash_block_size
def __len__(self) -> int:
return len(self.block_hashes) // self.scale_factor
@overload
def __getitem__(self, idx: int) -> BlockHash: ...
@overload
def __getitem__(self, idx: slice) -> list[BlockHash]: ...
def __getitem__(self, idx):
if isinstance(idx, int):
return self._get_value_at(idx)
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
return [self._get_value_at(i) for i in range(start, stop, step)]
raise TypeError(f"Invalid index type: {type(idx)!r}")
def __iter__(self) -> Iterator[BlockHash]:
for i in range(len(self)):
yield self._get_value_at(i)
def _get_value_at(self, idx: int) -> BlockHash:
base = idx * self.scale_factor
end = base + self.scale_factor
merged_hash: bytes = self.block_hashes[base]
for i in range(base + 1, end):
merged_hash += self.block_hashes[i]
return BlockHash(merged_hash)
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize

View File

@@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
hash_block_size=self.block_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

View File

@@ -7,7 +7,7 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
@@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
@abstractmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens). By default, it should
be set to the block_size.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
Returns:
A list of cached blocks with skipped blocks replaced by null block
@@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
else:
break
if use_eagle and computed_blocks[0]:
# Need to drop the last matched block if eagle is enabled.
for computed in computed_blocks:
computed.pop()
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
return computed_blocks
@@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
@@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if (
num_contiguous_blocks == 0
and block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
# Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
@@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, (
"aligned_length is not compatible with eagle now"
)
for computed in computed_blocks:
computed.pop()
return computed_blocks
@@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens).
Returns:
A list of cached blocks
@@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
assert kv_cache_spec.block_size == alignment_tokens, (
"KV cache groups with different block sizes are not compatible with "
"chunked local attention now"
)
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
@@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids))
)
max_num_blocks = max_length // kv_cache_spec.block_size
block_size = kv_cache_spec.block_size
max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
# When enable Mamba prefix caching, `block_size` will be aligned
# across full attention layers and Mamba layers to ensure the
# prefix hit length aligned at block
if (
block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0])
@@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: