[Core] Support Local Chunked Attention for Hybrid KV Cache (#19351)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
Lucia Fang authored on 2025-07-19 11:48:38 +08:00, committed by GitHub
parent 466e878f2a
commit 9a9fda1423
9 changed files with 351 additions and 19 deletions


@@ -87,6 +87,7 @@ class AttentionSpec(KVCacheSpec):
 @dataclass
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
+    attention_chunk_size: Optional[int] = None
     """
     When hybrid allocator is disabled and the model contains both full
     attention layers and sliding window attention layers, sliding
@@ -105,6 +106,17 @@ class FullAttentionSpec(AttentionSpec):
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+
+    @classmethod
+    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
+        if len(window_sizes) == 0:
+            return None
+        elif len(window_sizes) == 1:
+            return window_sizes.pop()
+        else:
+            raise ValueError(
+                "All attention layers in the same KV cache group must have the "
+                "same window size.")
 
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
@@ -114,14 +126,17 @@ class FullAttentionSpec(AttentionSpec):
         merged_spec = super().merge(specs)
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
-        if len(sliding_window) == 0:
-            merged_spec.sliding_window = None
-        elif len(sliding_window) == 1:
-            merged_spec.sliding_window = sliding_window.pop()
-        else:
-            raise ValueError(
-                "All sliding window layers in the same KV cache group "
-                "must have the same window size.")
+        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
+                                   if spec.attention_chunk_size is not None)
+        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
+        merged_spec.attention_chunk_size = (
+            cls.merge_window_sizes(attention_chunk_size))
+        assert (
+            (merged_spec.sliding_window is not None) +
+            (merged_spec.attention_chunk_size is not None) <= 1
+        ), ("Model with both sliding window layers and chunked local attention "
+            "layers is not supported.")
         return merged_spec
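
Reviewer note: merge() now resolves sliding_window and attention_chunk_size independently via merge_window_sizes and then asserts that at most one of them survives the merge, because mixing sliding window layers and chunked local attention layers in one merged spec is unsupported. A small standalone illustration of that mutual-exclusion check, using a hypothetical MergedSpec stand-in rather than the vLLM class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MergedSpec:
        sliding_window: Optional[int] = None
        attention_chunk_size: Optional[int] = None

    def check(spec: MergedSpec) -> None:
        # Mirrors the assert in FullAttentionSpec.merge(): the two booleans sum to at most 1.
        assert ((spec.sliding_window is not None) +
                (spec.attention_chunk_size is not None) <= 1), (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported.")

    check(MergedSpec(sliding_window=4096))        # passes
    check(MergedSpec(attention_chunk_size=8192))  # passes
    # check(MergedSpec(4096, 8192))               # would trip the assertion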
@@ -129,16 +144,26 @@ class FullAttentionSpec(AttentionSpec):
 class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int
 
-    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
-        max_model_len = vllm_config.model_config.max_model_len
-        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
 
     @property
     def type_id(self) -> str:
         return (
             f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
         )  # noqa
 
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        # During chunked prefill, we allocate KV cache for at most
+        # `self.attention_chunk_size` computed tokens plus the newly scheduled
+        # tokens. And we won't allocate KV cache for more than `max_model_len`
+        # tokens.
+        num_tokens = min(self.attention_chunk_size + max_num_batched_tokens,
+                         max_model_len)
+        return cdiv(num_tokens, self.block_size) * self.page_size_bytes
 
 
 @dataclass
 class SlidingWindowSpec(AttentionSpec):
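
To see what the new ChunkedLocalAttentionSpec.max_memory_usage_bytes buys, here is the bound worked through with made-up numbers (none of these values come from the PR): attention_chunk_size=8192, max_num_batched_tokens=2048, max_model_len=131072, block_size=16, and an assumed 64 KiB of KV cache per block.

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as used by the cdiv calls above.
        return -(-a // b)

    attention_chunk_size = 8192
    max_num_batched_tokens = 2048
    max_model_len = 131072
    block_size = 16
    page_size_bytes = 64 * 1024  # illustrative bytes per block

    # Old bound: enough blocks to hold the full model length.
    old_bytes = cdiv(max_model_len, block_size) * page_size_bytes
    # New bound: one chunk of computed tokens plus the newly scheduled tokens,
    # capped at the model length.
    num_tokens = min(attention_chunk_size + max_num_batched_tokens, max_model_len)
    new_bytes = cdiv(num_tokens, block_size) * page_size_bytes

    print(old_bytes // 2**20, "MiB ->", new_bytes // 2**20, "MiB")  # 512 MiB -> 40 MiB

With these assumed values the per-layer bound grows only with attention_chunk_size and the scheduler's batch budget instead of max_model_len, which is what lets the hybrid allocator give chunked local attention layers a much smaller KV cache.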