Enable V1 for Hybrid SSM/Attention Models (#20016)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
@@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
     ):
         self.kv_cache_config = kv_cache_config
         self.max_model_len = max_model_len
+        self.enable_caching = enable_caching
 
         self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
                                     enable_kv_cache_events)
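The base coordinator now keeps enable_caching on self, which the HybridKVCacheCoordinator change below uses to make its block-size divisibility check conditional on prefix caching.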
@@ -268,9 +269,13 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
 
         self.full_attention_block_size = self.full_attention_spec.block_size
         self.other_block_size = self.other_spec.block_size
-        assert self.other_block_size % self.full_attention_block_size == 0, (
-            "KVCacheCoordinator assumes the block_size of full attention "
-            "layers is divisible by other layers now.")
+
+        if self.enable_caching:
+            # this requirement is only needed for the prefix caching logic
+            divisible = self.other_block_size % self.full_attention_block_size
+            assert divisible == 0, (
+                "KVCacheCoordinator assumes the block_size of full "
+                "attention layers is divisible by other layers now.")
 
         if max(self.full_attention_group_ids) < min(self.other_group_ids):
             self.full_attn_first = True
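To illustrate why the divisibility requirement is only needed for prefix caching, here is a minimal sketch with hypothetical block sizes and hit lengths (not taken from the PR): cache hits are found at full-attention block granularity, and a divisible other block size lets the hit be aligned to a boundary the other KV cache group can also reuse.

# Minimal sketch; hypothetical values, not the PR's actual lookup logic.
full_attention_block_size = 16
other_block_size = 64  # e.g. a Mamba/SSM group; must be a multiple of 16

assert other_block_size % full_attention_block_size == 0

matched_full_attn_blocks = 7  # a cache hit of 7 * 16 = 112 tokens
matched_tokens = matched_full_attn_blocks * full_attention_block_size

# Align the hit down to a boundary both groups can represent.
aligned_tokens = matched_tokens // other_block_size * other_block_size
print(aligned_tokens)  # 64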
@@ -84,12 +84,15 @@ class KVCacheManager:
         self.log_stats = log_stats
         # FIXME: make prefix cache stats conditional on log_stats
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
-        assert len(
-            set(g.kv_cache_spec.block_size
-                for g in kv_cache_config.kv_cache_groups)
-        ) == 1, "Only one block size is supported for now"
-        self.block_size = kv_cache_config.kv_cache_groups[
-            0].kv_cache_spec.block_size
+
+        self.block_size: Optional[int] = None
+        if self.enable_caching:
+            assert len(
+                set(g.kv_cache_spec.block_size
+                    for g in kv_cache_config.kv_cache_groups)
+            ) == 1, "Only one block size is supported for now"
+            self.block_size = kv_cache_config.kv_cache_groups[
+                0].kv_cache_spec.block_size
 
         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
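A simplified sketch of the new KVCacheManager behaviour (an assumed shape, not the real constructor): the single-block-size assert and block_size itself now only exist when prefix caching is enabled, so hybrid models whose groups have different block sizes can run with caching disabled.

# Simplified sketch; not the real KVCacheManager constructor.
from typing import Optional

class ManagerSketch:
    def __init__(self, group_block_sizes: list[int], enable_caching: bool):
        self.enable_caching = enable_caching
        self.block_size: Optional[int] = None
        if enable_caching:
            assert len(set(group_block_sizes)) == 1, (
                "Only one block size is supported for now")
            self.block_size = group_block_sizes[0]

# Mixed block sizes (e.g. attention 16, Mamba 64) are fine without caching.
print(ManagerSketch([16, 64], enable_caching=False).block_size)  # None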
@@ -154,6 +157,7 @@ class KVCacheManager:
         # if the scheduler has tried to schedule the request before.
         block_hashes = self.req_to_block_hashes[request.request_id]
         if not block_hashes:
+            assert self.block_size is not None
            block_hashes = hash_request_tokens(self.caching_hash_fn,
                                               self.block_size, request)
            self.req_to_block_hashes[request.request_id] = block_hashes
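The added assert appears to encode the invariant that this hashing path is only reached when prefix caching is enabled, in which case block_size was populated in the constructor change above; it also narrows the Optional[int] type for the call below it.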
@@ -864,9 +864,11 @@ def _get_kv_cache_config_uniform_page_size(
         kv_cache_groups=kv_cache_groups,
     )
 
+    min_block_size = min(
+        [group.kv_cache_spec.block_size for group in kv_cache_groups])
+
     # Print the KV cache size and maximum concurrency.
-    num_tokens = num_blocks // len(
-        grouped_layers) * vllm_config.cache_config.block_size
+    num_tokens = num_blocks // len(grouped_layers) * min_block_size
     num_tokens_str = f"{num_tokens:,}"
     logger.info("GPU KV cache size: %s tokens", num_tokens_str)
     max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
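A worked example of the new size log with hypothetical numbers: the token count is now derived from the smallest block size across the KV cache groups rather than the global cache_config.block_size, presumably giving a conservative figure when the groups differ.

# Hypothetical values, for illustration only.
num_blocks = 8192             # total blocks in the shared pool
num_layer_groups = 2          # len(grouped_layers)
group_block_sizes = [16, 64]  # e.g. full-attention and Mamba groups

min_block_size = min(group_block_sizes)
num_tokens = num_blocks // num_layer_groups * min_block_size
print(f"GPU KV cache size: {num_tokens:,} tokens")  # GPU KV cache size: 65,536 tokens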