[BugFix] Correct max memory usage for multiple KV-cache groups (#36030)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
This commit is contained in:
@@ -43,6 +43,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpec,
|
||||
KVCacheTensor,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
@@ -157,6 +158,24 @@ def new_chunked_local_attention_spec(
|
||||
)
|
||||
|
||||
|
||||
def new_mamba_spec(
    block_size=16,
    shapes=((2, 512), (3, 32, 32)),
    dtypes=(torch.float32, torch.float32),
    num_speculative_blocks=2,
    mamba_cache_mode="none",
    page_size_padded=None,
):
    """Build a ``MambaSpec`` test fixture.

    Each keyword mirrors a ``MambaSpec`` field and carries a default so a
    test only needs to override the field it actually cares about.
    """
    # Gather the field values once, then expand into the constructor so
    # the mapping from argument to spec field is visible in one place.
    spec_fields = dict(
        block_size=block_size,
        shapes=shapes,
        dtypes=dtypes,
        page_size_padded=page_size_padded,
        mamba_cache_mode=mamba_cache_mode,
        num_speculative_blocks=num_speculative_blocks,
    )
    return MambaSpec(**spec_fields)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_none_hash(monkeypatch, hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
@@ -2010,6 +2029,28 @@ def test_auto_fit_max_model_len():
|
||||
assert vllm_config.model_config.max_model_len > 0
|
||||
|
||||
|
||||
def test_auto_fit_max_model_len_with_hybrid():
    """Auto-fit of ``max_model_len`` must work with hybrid KV cache specs.

    Sets ``original_max_model_len`` to -1 (the sentinel a user passes to
    request auto-fit) and checks that ``get_kv_cache_configs`` shrinks
    ``max_model_len`` to what the supplied memory budget can hold when the
    specs mix a Mamba layer with a regular attention layer.
    """
    cfg = ModelConfig(max_model_len=8192)
    # -1 is what a user would have passed; it triggers the auto-fit path.
    cfg.original_max_model_len = -1
    vllm_config = VllmConfig(model_config=cfg)

    # 16 KB per block per layer (same constant the sibling tests use).
    block_bytes = 16 * 2 * 64 * 4 * 2
    num_spec_blocks = 2
    specs = {
        "layer_1": new_mamba_spec(num_speculative_blocks=num_spec_blocks),
        "layer_2": new_kv_cache_spec(),
    }

    # Budget sized so a 1024-token context fits: 1024 // 16 regular blocks
    # plus one extra block plus the speculative blocks.
    target_blocks = 1024 // 16
    budget = block_bytes * (target_blocks + 1 + num_spec_blocks)
    get_kv_cache_configs(vllm_config, [specs], [budget])

    assert vllm_config.model_config.max_model_len == 1024
|
||||
|
||||
|
||||
def test_auto_fit_max_model_len_not_triggered():
|
||||
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
|
||||
model_config = ModelConfig(max_model_len=16)
|
||||
|
||||
@@ -1356,8 +1356,10 @@ def _max_memory_usage_bytes_from_groups(
|
||||
page_size = get_uniform_page_size(
|
||||
[group.kv_cache_spec for group in kv_cache_groups]
|
||||
)
|
||||
any_spec = kv_cache_groups[0].kv_cache_spec
|
||||
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
|
||||
blocks_needed = sum(
|
||||
cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
|
||||
for group in kv_cache_groups
|
||||
)
|
||||
|
||||
return group_size * page_size * blocks_needed
|
||||
|
||||
|
||||
Reference in New Issue
Block a user