[Hybrid Allocator] Support KV cache groups with different block_size (#29143)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
BlockHashList,
|
||||
BlockHashListWithBlockSize,
|
||||
BlockHashWithGroupId,
|
||||
ExternalBlockHash,
|
||||
FreeKVCacheBlockQueue,
|
||||
@@ -133,6 +135,10 @@ class BlockPool:
|
||||
Args:
|
||||
num_gpu_blocks: The number of blocks in the pool.
|
||||
enable_caching: Whether to enable prefix caching.
|
||||
hash_block_size: The block size of which the block hashes are computed.
|
||||
The actual block size usually equals hash_block_size, but in cases
|
||||
where different KV cache groups have different block sizes, the
|
||||
actual block size can be a multiple of hash_block_size.
|
||||
enable_kv_cache_events: Whether to enable kv cache events.
|
||||
"""
|
||||
|
||||
@@ -140,11 +146,13 @@ class BlockPool:
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
enable_caching: bool,
|
||||
hash_block_size: int,
|
||||
enable_kv_cache_events: bool = False,
|
||||
):
|
||||
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.enable_caching = enable_caching
|
||||
self.hash_block_size = hash_block_size
|
||||
# All kv-cache blocks.
|
||||
self.blocks: list[KVCacheBlock] = [
|
||||
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
|
||||
@@ -223,8 +231,20 @@ class BlockPool:
|
||||
return
|
||||
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
|
||||
assert len(request.block_hashes) >= num_full_blocks
|
||||
new_block_hashes = request.block_hashes[num_cached_blocks:]
|
||||
if block_size == self.hash_block_size:
|
||||
# Common case.
|
||||
block_hashes: BlockHashList = request.block_hashes
|
||||
else:
|
||||
# block_size is a multiple of hash_block_size. This happens when
|
||||
# different KV cache groups have different block sizes.
|
||||
assert block_size % self.hash_block_size == 0
|
||||
# Recalculate block_hashes at the granularity of block_size, using
|
||||
# the original block_hashes (at the granularity of hash_block_size).
|
||||
block_hashes = BlockHashListWithBlockSize(
|
||||
request.block_hashes, self.hash_block_size, block_size
|
||||
)
|
||||
|
||||
new_block_hashes = block_hashes[num_cached_blocks:]
|
||||
new_hashes: list[ExternalBlockHash] | None = (
|
||||
[] if self.enable_kv_cache_events else None
|
||||
)
|
||||
|
||||
@@ -2,15 +2,25 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from math import lcm
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
BlockHashList,
|
||||
BlockHashListWithBlockSize,
|
||||
KVCacheBlock,
|
||||
)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
CrossAttentionManager,
|
||||
FullAttentionManager,
|
||||
get_manager_for_kv_cache_spec,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
hash_block_size: int,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.max_model_len = max_model_len
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
self.block_pool = BlockPool(
|
||||
kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
|
||||
kv_cache_config.num_blocks,
|
||||
enable_caching,
|
||||
hash_block_size,
|
||||
enable_kv_cache_events,
|
||||
)
|
||||
|
||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||
@@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
hash_block_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=hash_block_size,
|
||||
)
|
||||
self.num_single_type_manager = len(self.single_type_managers)
|
||||
|
||||
@@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
hash_block_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=hash_block_size,
|
||||
)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
@@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
self.block_size *= dcp_world_size
|
||||
if pcp_world_size > 1:
|
||||
self.block_size *= pcp_world_size
|
||||
# For models using only Mamba, block_size is set to max_model_len when
|
||||
# prefix caching is disabled, and hash_block_size validation is skipped.
|
||||
assert not enable_caching or (hash_block_size == self.block_size), (
|
||||
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
|
||||
)
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
||||
)
|
||||
@@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
alignment_tokens=self.block_size,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
)
|
||||
@@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
hash_block_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=hash_block_size,
|
||||
)
|
||||
# hash_block_size: the block size used to compute block hashes.
|
||||
# The actual block size usually equals hash_block_size, but in cases where
|
||||
# different KV cache groups have different block sizes, the actual block size
|
||||
# can be a multiple of hash_block_size.
|
||||
self.hash_block_size = hash_block_size
|
||||
assert all(
|
||||
g.kv_cache_spec.block_size % hash_block_size == 0
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
), "block_size must be divisible by hash_block_size"
|
||||
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
||||
assert pcp_world_size == 1, "PCP not support hybrid attn now."
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
@@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
self.other_spec = other_spec
|
||||
self.full_attention_block_size = self.full_attention_spec.block_size
|
||||
self.other_block_size = self.other_spec.block_size
|
||||
|
||||
if self.enable_caching:
|
||||
# this requirement is only needed for the prefix caching logic
|
||||
divisible = self.other_block_size % self.full_attention_block_size
|
||||
assert divisible == 0, (
|
||||
"KVCacheCoordinator assumes the block_size of full "
|
||||
"attention layers is divisible by other layers now."
|
||||
)
|
||||
# The LCM of the block sizes of full attention and other attention.
|
||||
# The cache hit length must be a multiple of the LCM of the block sizes
|
||||
# to make sure the cache hit length is a multiple of the block size of
|
||||
# each attention type. Requiring this because we don't support partial
|
||||
# block cache hit yet.
|
||||
self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
|
||||
|
||||
if max(self.full_attention_group_ids) < min(self.other_group_ids):
|
||||
self.full_attn_first = True
|
||||
@@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
- The number of tokens of the longest cache hit.
|
||||
"""
|
||||
# First, find the longest cache hit for full attention.
|
||||
if self.full_attention_spec.block_size == self.hash_block_size:
|
||||
# Common case.
|
||||
full_attention_block_hashes: BlockHashList = block_hashes
|
||||
else:
|
||||
# block_size is a multiple of hash_block_size. This happens when different
|
||||
# KV cache groups have different block sizes. In this case, we need to
|
||||
# recalculate block_hashes at the granularity of block_size, using the
|
||||
# original block_hashes (at the granularity of hash_block_size).
|
||||
full_attention_block_hashes = BlockHashListWithBlockSize(
|
||||
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
|
||||
)
|
||||
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
block_hashes=full_attention_block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=self.full_attention_group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.full_attention_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
alignment_tokens=self.lcm_block_size,
|
||||
)
|
||||
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
|
||||
|
||||
# Next, find the cache hit for the other attention WITHIN
|
||||
# the cache hit of full attention.
|
||||
if self.other_spec.block_size == self.hash_block_size:
|
||||
# Common case.
|
||||
other_block_hashes: BlockHashList = block_hashes
|
||||
else:
|
||||
# Similar to the full attention case, here we need to recalculate
|
||||
# block_hashes at the granularity of block_size, using the original
|
||||
# block_hashes (at the granularity of hash_block_size).
|
||||
other_block_hashes = BlockHashListWithBlockSize(
|
||||
block_hashes, self.hash_block_size, self.other_spec.block_size
|
||||
)
|
||||
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
block_hashes=other_block_hashes,
|
||||
max_length=hit_length,
|
||||
kv_cache_group_ids=self.other_group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.other_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
alignment_tokens=self.lcm_block_size,
|
||||
)
|
||||
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
||||
|
||||
@@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
hash_block_size: int,
|
||||
) -> KVCacheCoordinator:
|
||||
if not enable_caching:
|
||||
return KVCacheCoordinatorNoPrefixCache(
|
||||
@@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
dcp_world_size,
|
||||
pcp_world_size,
|
||||
hash_block_size,
|
||||
)
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(
|
||||
@@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
|
||||
use_eagle,
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
dcp_world_size,
|
||||
pcp_world_size,
|
||||
hash_block_size,
|
||||
)
|
||||
return HybridKVCacheCoordinator(
|
||||
kv_cache_config,
|
||||
@@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
|
||||
use_eagle,
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
dcp_world_size,
|
||||
pcp_world_size,
|
||||
hash_block_size,
|
||||
)
|
||||
|
||||
@@ -95,6 +95,7 @@ class KVCacheManager:
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
hash_block_size: int,
|
||||
enable_caching: bool = True,
|
||||
use_eagle: bool = False,
|
||||
log_stats: bool = False,
|
||||
@@ -107,28 +108,11 @@ class KVCacheManager:
|
||||
self.enable_caching = enable_caching
|
||||
self.use_eagle = use_eagle
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats. We still need
|
||||
# this comment because when the log stats is enabled there are still
|
||||
# potential configs we could expose in the future.
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
|
||||
self.block_size: int | None = None
|
||||
if self.enable_caching:
|
||||
assert (
|
||||
len(
|
||||
set(
|
||||
g.kv_cache_spec.block_size
|
||||
for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
)
|
||||
== 1
|
||||
), "Only one block size is supported for now"
|
||||
self.block_size = kv_cache_config.kv_cache_groups[
|
||||
0
|
||||
].kv_cache_spec.block_size
|
||||
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
self.block_size *= dcp_world_size * pcp_world_size
|
||||
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
@@ -137,6 +121,7 @@ class KVCacheManager:
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=hash_block_size,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
import copy
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NewType, TypeAlias
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, NewType, TypeAlias, overload
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
@@ -825,11 +825,11 @@ def get_num_blocks(
|
||||
return num_blocks
|
||||
|
||||
|
||||
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
|
||||
def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
|
||||
"""
|
||||
Get the page size of the KV cache.
|
||||
"""
|
||||
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
|
||||
page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
|
||||
assert len(page_sizes) == 1
|
||||
return page_sizes.pop()
|
||||
|
||||
@@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
|
||||
return len(page_sizes) == 1
|
||||
|
||||
|
||||
def unify_kv_cache_spec_page_size(
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Unify the page size of the given KVCacheSpec. If the page size of all layers
|
||||
are the same, return the original KVCacheSpec. If not same, unify the page
|
||||
size by increasing the block size of layers with smaller page size. Raise
|
||||
NotImplementedError if failed to unify the page size.
|
||||
|
||||
Args:
|
||||
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||
|
||||
Returns:
|
||||
The updated KVCacheSpec with the same page_size_bytes.
|
||||
"""
|
||||
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
|
||||
if len(page_sizes) <= 1:
|
||||
# All layers have the same page size, no need to unify.
|
||||
return kv_cache_spec
|
||||
|
||||
max_page_size = max(page_sizes)
|
||||
new_kv_cache_spec = {}
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
if layer_spec.page_size_bytes == max_page_size:
|
||||
new_kv_cache_spec[layer_name] = layer_spec
|
||||
else:
|
||||
layer_page_size = layer_spec.page_size_bytes
|
||||
if max_page_size % layer_page_size != 0:
|
||||
raise NotImplementedError(
|
||||
"The page size of the layer is not divisible by the "
|
||||
"maximum page size. Cannot unify by adjusting block_size."
|
||||
)
|
||||
ratio = max_page_size // layer_page_size
|
||||
new_block_size = layer_spec.block_size * ratio
|
||||
new_spec = replace(layer_spec, block_size=new_block_size)
|
||||
assert new_spec.page_size_bytes == max_page_size
|
||||
new_kv_cache_spec[layer_name] = new_spec
|
||||
return new_kv_cache_spec
|
||||
|
||||
|
||||
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
# kv_cache_spec is an empty dict for attention free models
|
||||
return not kv_cache_spec
|
||||
@@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
|
||||
def get_kv_cache_config_from_groups(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
kv_cache_specs: dict[str, KVCacheSpec],
|
||||
available_memory: int,
|
||||
) -> KVCacheConfig:
|
||||
"""
|
||||
@@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_groups: The KV cache groups
|
||||
kv_cache_specs: The KV cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
@@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
|
||||
# full.1, sw.2: share another Tensor with size=available_memory//2
|
||||
group_size = max(len(group.layer_names) for group in kv_cache_groups)
|
||||
|
||||
page_size = get_uniform_page_size(kv_cache_specs)
|
||||
page_size = get_uniform_page_size(
|
||||
[group.kv_cache_spec for group in kv_cache_groups]
|
||||
)
|
||||
assert group_size > 0, "group_size must be greater than 0"
|
||||
num_blocks = get_num_blocks(
|
||||
vllm_config, group_size, available_memory, page_size
|
||||
@@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
|
||||
# This returns an empty list to allow for the KVCacheManager to handle
|
||||
# attention free models.
|
||||
return []
|
||||
elif is_kv_cache_spec_uniform(kv_cache_spec):
|
||||
|
||||
if is_kv_cache_spec_uniform(kv_cache_spec):
|
||||
# KV cache of all layers are the same, which is true for
|
||||
# most models. Allocate the same amount of memory for
|
||||
# each layer.
|
||||
@@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
|
||||
# full attention, or all layers are sliding window attention with the
|
||||
# same window size). Put all layers into one group.
|
||||
return _get_kv_cache_groups_uniform_type(uniform_spec)
|
||||
elif is_kv_cache_page_size_uniform(kv_cache_spec):
|
||||
# Model contains multiple attention types, but KV cache of all layers
|
||||
# have the same physical memory per block per layer. Split the layers
|
||||
# into groups with the same number of layers, and thus same total page
|
||||
# size.
|
||||
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
||||
|
||||
raise NotImplementedError
|
||||
# As KVCacheManager can only allocate memory of one size, we need to unify
|
||||
# the page size of the layers. For cases cannot be unified, this function
|
||||
# will raise an error.
|
||||
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
|
||||
# Model contains multiple attention types, but KV cache of all layers
|
||||
# have the same physical memory per block per layer. Split the layers
|
||||
# into groups with the same number of layers, and thus same total page
|
||||
# size.
|
||||
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
||||
|
||||
|
||||
def generate_scheduler_kv_cache_config(
|
||||
@@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
|
||||
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
|
||||
kv_cache_configs.append(
|
||||
get_kv_cache_config_from_groups(
|
||||
vllm_config,
|
||||
kv_cache_groups_one_worker,
|
||||
kv_cache_spec_one_worker,
|
||||
available_memory_one_worker,
|
||||
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
|
||||
_report_kv_cache_config(vllm_config, kv_cache_config)
|
||||
|
||||
return kv_cache_configs
|
||||
|
||||
|
||||
class BlockHashListWithBlockSize:
|
||||
"""
|
||||
Convert block-hash granularity from `hash_block_size` to `target_block_size`.
|
||||
Used when KV cache groups have different block sizes: `hash_block_size`
|
||||
is the size used to compute the original `block_hashes`; `target_block_size`
|
||||
is the group's actual block size.
|
||||
|
||||
Currently, only scaling up by an integer factor is supported (i.e.,
|
||||
`target_block_size` is a multiple of `hash_block_size`). Conversion is
|
||||
performed lazily on access for efficiency, by concatenating consecutive
|
||||
hashes at `hash_block_size` to form each hash at `target_block_size`.
|
||||
|
||||
Example (`hash_block_size` = 16, `target_block_size` = 32):
|
||||
concatenating two 16-size hashes yields one 32-size hash:
|
||||
|
||||
Block hashes with block_size 16:
|
||||
| Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
|
||||
|-------------|------|-------|-------|-------|
|
||||
| Hash | A | B | C | D |
|
||||
|
||||
Block hashes with block_size 32:
|
||||
| Token Range | 0-31 | 32-63 |
|
||||
|-------------|------|-------|
|
||||
| Hash | AB | CD |
|
||||
|
||||
Args:
|
||||
block_hashes: Block hashes to convert, computed at `hash_block_size`.
|
||||
hash_block_size: Block size at which `block_hashes` were computed.
|
||||
target_block_size: Desired block size; must be a multiple of `hash_block_size`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_hashes: list[BlockHash],
|
||||
hash_block_size: int,
|
||||
target_block_size: int,
|
||||
):
|
||||
self.block_hashes = block_hashes
|
||||
assert target_block_size % hash_block_size == 0
|
||||
self.scale_factor = target_block_size // hash_block_size
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.block_hashes) // self.scale_factor
|
||||
|
||||
@overload
|
||||
def __getitem__(self, idx: int) -> BlockHash: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, idx: slice) -> list[BlockHash]: ...
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
return self._get_value_at(idx)
|
||||
|
||||
if isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
return [self._get_value_at(i) for i in range(start, stop, step)]
|
||||
|
||||
raise TypeError(f"Invalid index type: {type(idx)!r}")
|
||||
|
||||
def __iter__(self) -> Iterator[BlockHash]:
|
||||
for i in range(len(self)):
|
||||
yield self._get_value_at(i)
|
||||
|
||||
def _get_value_at(self, idx: int) -> BlockHash:
|
||||
base = idx * self.scale_factor
|
||||
end = base + self.scale_factor
|
||||
merged_hash: bytes = self.block_hashes[base]
|
||||
for i in range(base + 1, end):
|
||||
merged_hash += self.block_hashes[i]
|
||||
return BlockHash(merged_hash)
|
||||
|
||||
|
||||
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
|
||||
|
||||
@@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
hash_block_size=self.block_size,
|
||||
)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections.abc import Sequence
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
@@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
@@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
alignment_tokens: The returned cache hit length (in tokens) should
|
||||
be a multiple of this value (in tokens). By default, it should
|
||||
be set to the block_size.
|
||||
dcp_world_size: The world size of decode context parallelism.
|
||||
pcp_world_size: The world size of prefill context parallelism.
|
||||
|
||||
Returns:
|
||||
A list of cached blocks with skipped blocks replaced by null block
|
||||
@@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
||||
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
|
||||
), (
|
||||
"FullAttentionManager can only be used for full attention "
|
||||
"and chunked local attention groups"
|
||||
@@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
else:
|
||||
break
|
||||
if use_eagle and computed_blocks[0]:
|
||||
# Need to drop the last matched block if eagle is enabled.
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
while (
|
||||
block_size != alignment_tokens # Faster for common case.
|
||||
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
|
||||
):
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
return computed_blocks
|
||||
@@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
@@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
[block_pool.null_block] * max_num_blocks
|
||||
for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
block_size = kv_cache_spec.block_size
|
||||
num_contiguous_blocks = 0
|
||||
match_found = False
|
||||
# Search from right to left and early stop when a match is found.
|
||||
@@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hashes[i], kv_cache_group_ids
|
||||
):
|
||||
# Skip prefix matching check if the block is not aligned with
|
||||
# `alignment_tokens`.
|
||||
if (
|
||||
num_contiguous_blocks == 0
|
||||
and block_size != alignment_tokens # Faster for common case.
|
||||
and (i + 1) * block_size % alignment_tokens != 0
|
||||
):
|
||||
continue
|
||||
# Add the cached block to the computed blocks.
|
||||
for computed, cached in zip(computed_blocks, cached_block):
|
||||
computed[i] = cached
|
||||
num_contiguous_blocks += 1
|
||||
@@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||
for computed in computed_blocks:
|
||||
del computed[num_contiguous_blocks:]
|
||||
while (
|
||||
block_size != alignment_tokens # Faster for common case.
|
||||
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
|
||||
):
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
if use_eagle and computed_blocks[0]:
|
||||
assert kv_cache_spec.block_size == alignment_tokens, (
|
||||
"aligned_length is not compatible with eagle now"
|
||||
)
|
||||
for computed in computed_blocks:
|
||||
computed.pop()
|
||||
return computed_blocks
|
||||
@@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
@@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
dcp_world_size: The world size of decode context parallelism.
|
||||
pcp_world_size: The world size of prefill context parallelism.
|
||||
alignment_tokens: The returned cache hit length (in tokens) should
|
||||
be a multiple of this value (in tokens).
|
||||
|
||||
Returns:
|
||||
A list of cached blocks
|
||||
@@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
||||
assert kv_cache_spec.block_size == alignment_tokens, (
|
||||
"KV cache groups with different block sizes are not compatible with "
|
||||
"chunked local attention now"
|
||||
)
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
if max_length > 0:
|
||||
local_attention_start_idx = (
|
||||
@@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
@@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
block_size = kv_cache_spec.block_size
|
||||
max_num_blocks = max_length // block_size
|
||||
# Search from right to left and early stop when a match is found.
|
||||
for i in range(max_num_blocks - 1, -1, -1):
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hashes[i], kv_cache_group_ids
|
||||
):
|
||||
# When enable Mamba prefix caching, `block_size` will be aligned
|
||||
# across full attention layers and Mamba layers to ensure the
|
||||
# prefix hit length aligned at block
|
||||
if (
|
||||
block_size != alignment_tokens # Faster for common case.
|
||||
and (i + 1) * block_size % alignment_tokens != 0
|
||||
):
|
||||
continue
|
||||
for computed, cached in zip(computed_blocks, cached_block):
|
||||
# the hit length logic later assumes:
|
||||
# hit_length = len(hit_blocks_other_attn[0])
|
||||
@@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
|
||||
Reference in New Issue
Block a user