2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
2025-01-17 15:39:35 +08:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from typing import Dict, List
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
|
from vllm.utils import cdiv, get_dtype_size
|
|
|
|
|
|
|
|
|
|
# Module-level logger obtained from `init_logger`, keyed by this module's name.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class KVCacheSpecBase:
    """
    Common interface describing the KV cache format of a single layer.

    The members `type_id`, `page_size_bytes` and `bytes_for_tokens` defined
    here only raise `NotImplementedError`; concrete specs are expected to
    override them.
    """

    # Number of tokens stored in one block.
    block_size: int

    @property
    def type_id(self) -> str:
        """
        Identifier string for the kind of KV cache used by this layer.

        Layers whose KV caches are of a different type must yield different
        strings, e.g. when they hold a different number of tokens (full
        attention vs. sliding window attention) or need a different KV cache
        size per token (layers with different numbers of heads).

        Returns:
            The type identifier of this KV cache.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page, i.e. a page holding `block_size` tokens.

        Returns:
            The page size.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def bytes_for_tokens(self, num_tokens: int) -> int:
        """
        KV cache size in bytes needed for `num_tokens` tokens.

        The value is the real memory footprint, i.e. the size obtained after
        `num_tokens` has been padded up to a whole number of blocks.

        Args:
            num_tokens: Number of tokens to account for.

        Returns:
            The KV cache size.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class FullAttentionSpec(KVCacheSpecBase):
    """
    KV cache spec for a full attention layer.

    Adds the number of KV heads, the per-head size and the element dtype to
    the `block_size` inherited from `KVCacheSpecBase`.
    """

    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def type_id(self) -> str:
        """
        Identifier built from the block size and the page size in bytes.

        Returns:
            A string of the form `full_attention_<block_size>_<page_size>`.
        """
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"

    @property
    def page_size_bytes(self) -> int:
        """
        Size in bytes of one page holding `block_size` tokens.

        Returns:
            `2 * block_size * num_kv_heads * head_size` multiplied by the
            byte size of `dtype`.
        """
        # NOTE(review): the factor 2 presumably covers the key and the value
        # tensors; confirm against the cache layout used by the workers.
        elements_per_page = (2 * self.block_size * self.num_kv_heads *
                             self.head_size)
        return elements_per_page * get_dtype_size(self.dtype)

    def bytes_for_tokens(self, num_tokens: int) -> int:
        """
        KV cache size in bytes for `num_tokens` tokens, rounded up to whole
        blocks.

        Args:
            num_tokens: Number of tokens to account for.

        Returns:
            Number of blocks needed (ceiling of `num_tokens / block_size`)
            times `page_size_bytes`.
        """
        blocks_needed = cdiv(num_tokens, self.block_size)
        return blocks_needed * self.page_size_bytes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Mapping from a string key to the KV cache spec of one layer.
# NOTE(review): keys are presumably layer names -- confirm against callers.
KVCacheSpec = Dict[str, KVCacheSpecBase]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class KVCacheTensor:
    """
    Tells the workers how to initialize the KV cache of one layer.

    For now it only carries the size of that layer's KV cache. It is planned
    to be extended so that multiple layers can share the same memory pool.
    """

    # Size of the KV cache tensor, in bytes.
    size: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class KVCacheConfig:
    """
    The KV cache configuration of a model.
    """

    # Field descriptions are plain comments placed directly above the field
    # they describe. Bare string literals placed *before* a field are
    # associated by documentation tools with the preceding assignment, which
    # would mislabel every field.

    # The number of KV cache blocks.
    num_blocks: int

    # layer_name -> how to initialize the KV cache for that layer.
    tensors: Dict[str, KVCacheTensor]

    # A list of kv-cache groups. Each group includes a set of layers with
    # the same kv-cache spec, and the total page_size of layers inside a group
    # is the same across all groups (as the KVCacheManager only supports
    # allocating pages of the same size). For example:
    # 1. A model only uses full attention: one group with all layers in the
    #    model.
    # 2. (not implemented yet) A model with the same number of full attention
    #    layers and sliding window attention layers: two groups, one for full
    #    attention layers and one for sliding window attention layers.
    # 3. (not implemented yet) A model with 2 full attention layers and 4
    #    sliding window attention layers: three groups, (full * 2), (sw * 2),
    #    (sw * 2).
    groups: List[List[str]]

    # The KVCacheSpec of the model.
    kv_cache_spec: KVCacheSpec
|