[BugFix] Correct max memory usage for multiple KV-cache groups (#36030)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
This commit is contained in:
Harry Huang
2026-03-17 08:38:52 +08:00
committed by GitHub
parent 5db91f0aaf
commit 45f526d652
2 changed files with 45 additions and 2 deletions

View File

@@ -43,6 +43,7 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
@@ -157,6 +158,24 @@ def new_chunked_local_attention_spec(
)
def new_mamba_spec(
    block_size=16,
    shapes=((2, 512), (3, 32, 32)),
    dtypes=(torch.float32, torch.float32),
    num_speculative_blocks=2,
    mamba_cache_mode="none",
    page_size_padded=None,
):
    """Build a ``MambaSpec`` test fixture.

    Each keyword is forwarded verbatim to the ``MambaSpec`` constructor;
    override any of them in a test to exercise a specific configuration.
    """
    # Gather the forwarded arguments in one place, then unpack them into
    # the constructor so the mapping reads as a single table.
    spec_kwargs = {
        "block_size": block_size,
        "shapes": shapes,
        "dtypes": dtypes,
        "page_size_padded": page_size_padded,
        "mamba_cache_mode": mamba_cache_mode,
        "num_speculative_blocks": num_speculative_blocks,
    }
    return MambaSpec(**spec_kwargs)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
@@ -2010,6 +2029,28 @@ def test_auto_fit_max_model_len():
assert vllm_config.model_config.max_model_len > 0
def test_auto_fit_max_model_len_with_hybrid():
    """Test that auto-fit works with hybrid KV cache specs."""
    # Setting original_max_model_len to -1 simulates the user passing -1,
    # which is what triggers the auto-fit path.
    model_config = ModelConfig(max_model_len=8192)
    model_config.original_max_model_len = -1
    vllm_config = VllmConfig(model_config=model_config)

    block_bytes_per_layer = 16 * 2 * 64 * 4 * 2  # 16KB per block per layer
    num_spec_blocks = 2

    # Hybrid layout: one mamba layer alongside one regular attention layer.
    kv_cache_specs = {
        "layer_1": new_mamba_spec(num_speculative_blocks=num_spec_blocks),
        "layer_2": new_kv_cache_spec(),
    }

    # Budget: (1024 // 16) blocks plus one extra block plus the speculative
    # blocks — sized so the fitted max_model_len lands on exactly 1024
    # (checked by the assert below).
    available_memory = block_bytes_per_layer * (
        1024 // 16 + 1 + num_spec_blocks
    )

    _kv_cache_configs = get_kv_cache_configs(
        vllm_config, [kv_cache_specs], [available_memory]
    )
    assert vllm_config.model_config.max_model_len == 1024
def test_auto_fit_max_model_len_not_triggered():
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
model_config = ModelConfig(max_model_len=16)

View File

@@ -1356,8 +1356,10 @@ def _max_memory_usage_bytes_from_groups(
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
any_spec = kv_cache_groups[0].kv_cache_spec
blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
blocks_needed = sum(
cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
for group in kv_cache_groups
)
return group_size * page_size * blocks_needed