[Hybrid Allocator] Support full attention with different hidden size (#25101)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -234,6 +234,76 @@ class CrossAttentionSpec(AttentionSpec):
        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes


@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
    """
    A KV cache spec for multiple layers with the same type of attention. Here,
    same type means the layers always need the same number of token slots. For
    example, sliding window attentions with different window sizes are not the
    same type and should not be merged into one UniformTypeKVCacheSpecs.
    """
    kv_cache_specs: dict[str, KVCacheSpec]

    @property
    def page_size_bytes(self) -> int:
        return sum(spec.page_size_bytes
                   for spec in self.kv_cache_specs.values())

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        max_num_pages = max(
            cdiv(spec.max_memory_usage_bytes(vllm_config),
                 spec.page_size_bytes)
            for spec in self.kv_cache_specs.values())
        return max_num_pages * self.page_size_bytes

    @classmethod
    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
        """
        Whether all layers have the same type of KV cache spec.
        """
        block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
        if len(block_sizes) > 1:
            # Different block sizes, not uniform.
            return False
        one_spec = next(iter(kv_cache_specs.values()))
        if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)):
            return all(
                isinstance(spec, type(one_spec))
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, SlidingWindowSpec):
            return all(
                isinstance(spec, SlidingWindowSpec)
                and spec.sliding_window == one_spec.sliding_window
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
            return all(
                isinstance(spec, ChunkedLocalAttentionSpec)
                and spec.attention_chunk_size == one_spec.attention_chunk_size
                for spec in kv_cache_specs.values())
        elif isinstance(one_spec, MambaSpec):
            return all(
                isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
                one_spec.num_speculative_blocks
                for spec in kv_cache_specs.values())
        else:
            # NOTE(Chen): Please add new branches for new KV cache spec types.
            raise NotImplementedError(
                f"Unsupported KV cache spec type: {type(one_spec)}")

    @classmethod
    def from_specs(cls, kv_cache_specs: dict[str,
                                             KVCacheSpec]) -> Optional[Self]:
        """
        Return a UniformTypeKVCacheSpecs object if all layers have the same
        type of KV cache spec. Return None if not.
        """
        if cls.is_uniform_type(kv_cache_specs):
            block_size = next(iter(kv_cache_specs.values())).block_size
            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
        else:
            return None


@dataclass
class KVCacheTensor:
    """
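
A hedged usage sketch (not part of the commit): how the new class might merge two full-attention layers whose hidden sizes differ, which is the case the title refers to. The import path and the FullAttentionSpec constructor fields (num_kv_heads, head_size, dtype) are assumptions based on the surrounding vLLM code and may not match the exact signatures.

import torch
from vllm.v1.kv_cache_interface import (FullAttentionSpec,
                                        UniformTypeKVCacheSpecs)

# Two full-attention layers with the same block size but different hidden
# sizes (head_size 128 vs. 64); the constructor fields are assumptions.
specs = {
    "model.layers.0.self_attn": FullAttentionSpec(
        block_size=16, num_kv_heads=8, head_size=128, dtype=torch.bfloat16),
    "model.layers.1.self_attn": FullAttentionSpec(
        block_size=16, num_kv_heads=8, head_size=64, dtype=torch.bfloat16),
}

merged = UniformTypeKVCacheSpecs.from_specs(specs)
assert merged is not None  # same attention type and block size, so mergeable
# One merged page holds one block for every layer, so its size is the sum of
# the per-layer page sizes.
print(merged.page_size_bytes)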
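
And a quick arithmetic check on max_memory_usage_bytes, with made-up per-layer numbers (not from the commit): the merged spec takes the largest per-layer page count and charges every page at the summed page size.

from math import ceil

# Illustrative numbers only: two layers that need the same number of token
# slots but have different page sizes (e.g. different hidden sizes).
layer_max_bytes = {"a": 1_000_000, "b": 1_000_000}
layer_page_bytes = {"a": 4_096, "b": 2_048}

max_num_pages = max(
    ceil(layer_max_bytes[name] / layer_page_bytes[name])
    for name in layer_max_bytes)                      # max(245, 489) = 489
merged_page_size = sum(layer_page_bytes.values())     # 4_096 + 2_048 = 6_144
print(max_num_pages * merged_page_size)               # 489 * 6_144 = 3_004_416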