[V1] Implement sliding window attention in kv_cache_manager (#14097)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 
 import torch
 
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import cdiv, get_dtype_size
 
@@ -43,28 +44,23 @@ class KVCacheSpec:
         """
         raise NotImplementedError
 
-    def bytes_for_tokens(self, num_tokens: int) -> int:
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         """
-        The KV cache size for `num_tokens` tokens in bytes. Returns the real
-        memory size after padding `num_tokens` to full blocks.
+        The maximum possible memory usage of this KV cache in bytes.
 
         Returns:
-            The KV cache size
+            The KV cache size in bytes
         """
         raise NotImplementedError
 
 
 @dataclass
-class FullAttentionSpec(KVCacheSpec):
+class AttentionSpec(KVCacheSpec):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
     use_mla: bool
 
-    @property
-    def type_id(self) -> str:
-        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
-
     @property
     def page_size_bytes(self) -> int:
         # For MLA we only store a single latent vector
@@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
         return coef * self.block_size * self.num_kv_heads * self.head_size \
             * get_dtype_size(self.dtype)
 
-    def bytes_for_tokens(self, num_tokens: int) -> int:
-        return cdiv(num_tokens, self.block_size) * self.page_size_bytes
+
+@dataclass
+class FullAttentionSpec(AttentionSpec):
+
+    @property
+    def type_id(self) -> str:
+        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+
+
+@dataclass
+class SlidingWindowSpec(AttentionSpec):
+    sliding_window: int
+
+    def __post_init__(self):
+        assert not self.use_mla, "MLA is not supported for sliding window"
+
+    @property
+    def type_id(self) -> str:
+        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+
+        # During chunked prefill, we allocate KV cache for the last
+        # `self.sliding_window - 1` computed tokens plus the newly
+        # scheduled tokens, and we never allocate KV cache for more
+        # than `max_model_len` tokens.
+        num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
+                         max_model_len)
+
+        # +1 here because the sliding window may not start at the
+        # beginning of a block. For example, if the block size is 4 and
+        # num_tokens is 4, we need two blocks [XXCD] [EF] to store the
+        # sliding window [CDEF] of 4 tokens.
+        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
 
 
 @dataclass
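
For intuition, `page_size_bytes` is just a K/V coefficient times the block geometry. A minimal standalone sketch of the arithmetic; the dimensions below are made-up illustration values, not taken from this commit:

# Sketch of the page_size_bytes arithmetic with assumed dimensions.
block_size = 16     # tokens per KV cache block
num_kv_heads = 8
head_size = 128
dtype_size = 2      # bytes per element for fp16/bf16

# coef = 1 for MLA (a single latent vector per token), 2 otherwise
# (separate K and V entries per token).
for use_mla in (False, True):
    coef = 1 if use_mla else 2
    page_size_bytes = (coef * block_size * num_kv_heads * head_size *
                       dtype_size)
    print(f"use_mla={use_mla}: {page_size_bytes} bytes per block")
    # use_mla=False: 65536 bytes; use_mla=True: 32768 bytes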
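
`FullAttentionSpec.max_memory_usage_bytes` is then the number of blocks needed to hold `max_model_len` tokens, rounded up. A sketch with an assumed `max_model_len` (`cdiv` is reimplemented here to stay self-contained):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 32768      # assumed context length, for illustration
block_size = 16
page_size_bytes = 65536    # from the sketch above

# Worst case for full attention: every position of the longest possible
# sequence holds KV cache, padded up to whole blocks.
max_bytes = cdiv(max_model_len, block_size) * page_size_bytes
print(max_bytes // 2**20, "MiB per layer")   # 2048 blocks -> 128 MiB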
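
The sliding-window bound is tighter: under chunked prefill only the trailing `sliding_window - 1` computed tokens plus the newly scheduled chunk can hold live KV cache, and one extra block covers a window that straddles a block boundary (the `[XXCD] [EF]` example in the comment above). A self-contained sketch of the same calculation, again with assumed config values:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.cdiv.
    return -(a // -b)

# Assumed illustration values, not from a real vLLM config.
sliding_window = 4096
max_num_batched_tokens = 2048   # chunked-prefill scheduling budget
max_model_len = 32768
block_size = 16
page_size_bytes = 65536

# Tokens that can need KV cache at once: the trailing window of computed
# tokens plus the current chunk, capped at the model's maximum length.
num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)

# +1 block because the window may not start block-aligned: with
# block_size=4, the 4-token window [CDEF] spans the blocks [XXCD][EF],
# i.e. cdiv(4, 4) + 1 = 2 blocks.
num_blocks = cdiv(num_tokens, block_size) + 1
print(num_blocks, "blocks ->", num_blocks * page_size_bytes, "bytes")

Compared with the full-attention bound, this caps allocation near the window size plus the scheduling budget rather than the full context length, which is the point of the change.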