diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 08463a280..8153fed69 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -43,6 +43,7 @@ from vllm.v1.kv_cache_interface import (
     KVCacheGroupSpec,
     KVCacheSpec,
     KVCacheTensor,
+    MambaSpec,
     MLAAttentionSpec,
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
@@ -157,6 +158,24 @@ def new_chunked_local_attention_spec(
     )


+def new_mamba_spec(
+    block_size=16,
+    shapes=((2, 512), (3, 32, 32)),
+    dtypes=(torch.float32, torch.float32),
+    num_speculative_blocks=2,
+    mamba_cache_mode="none",
+    page_size_padded=None,
+):
+    return MambaSpec(
+        block_size=block_size,
+        shapes=shapes,
+        dtypes=dtypes,
+        page_size_padded=page_size_padded,
+        mamba_cache_mode=mamba_cache_mode,
+        num_speculative_blocks=num_speculative_blocks,
+    )
+
+
 @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
 def test_none_hash(monkeypatch, hash_fn):
     import vllm.v1.core.kv_cache_utils
@@ -2010,6 +2029,28 @@ def test_auto_fit_max_model_len():
     assert vllm_config.model_config.max_model_len > 0


+def test_auto_fit_max_model_len_with_hybrid():
+    """Test that auto-fit works with hybrid KV cache specs."""
+    # Create config with original_max_model_len=-1 to trigger auto-fit
+    model_config = ModelConfig(max_model_len=8192)
+    # Simulate the user passing -1 by setting original_max_model_len
+    model_config.original_max_model_len = -1
+    vllm_config = VllmConfig(model_config=model_config)
+
+    mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2  # 16KB per block per layer
+    gamma = 2
+    kv_cache_specs = {
+        "layer_1": new_mamba_spec(num_speculative_blocks=gamma),
+        "layer_2": new_kv_cache_spec(),
+    }
+
+    available_memory = mem_per_block_per_layer * (1024 // 16 + 1 + gamma)
+    _kv_cache_configs = get_kv_cache_configs(
+        vllm_config, [kv_cache_specs], [available_memory]
+    )
+    assert vllm_config.model_config.max_model_len == 1024
+
+
 def test_auto_fit_max_model_len_not_triggered():
     """Test that auto-fit is not triggered when original_max_model_len is not -1."""
     model_config = ModelConfig(max_model_len=16)
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 3da3d7e7b..83ada0530 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -1356,8 +1356,10 @@ def _max_memory_usage_bytes_from_groups(
     page_size = get_uniform_page_size(
         [group.kv_cache_spec for group in kv_cache_groups]
     )
-    any_spec = kv_cache_groups[0].kv_cache_spec
-    blocks_needed = cdiv(any_spec.max_memory_usage_bytes(vllm_config), page_size)
+    blocks_needed = sum(
+        cdiv(group.kv_cache_spec.max_memory_usage_bytes(vllm_config), page_size)
+        for group in kv_cache_groups
+    )
     return group_size * page_size * blocks_needed
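
The functional change here: `_max_memory_usage_bytes_from_groups` previously sized the block budget from `kv_cache_groups[0]` alone, which undercounts for hybrid models whose groups have different specs; it now sums the per-group block counts. As a sanity check on the new test's numbers, below is a minimal standalone sketch (not vLLM code) of the budget arithmetic, under the assumption that the attention layer needs one block per 16 tokens while the Mamba layer needs one state block plus `gamma` speculative blocks:

```python
# Standalone sketch of the hybrid test's memory budget (assumptions noted inline).
BLOCK_SIZE = 16
MEM_PER_BLOCK_PER_LAYER = 16 * 2 * 64 * 4 * 2  # 16 KiB, as in the test
GAMMA = 2  # num_speculative_blocks passed to new_mamba_spec


def blocks_needed(max_model_len: int) -> int:
    attention_blocks = -(-max_model_len // BLOCK_SIZE)  # ceiling division (cdiv)
    mamba_blocks = 1 + GAMMA  # assumed: one state block + GAMMA speculative blocks
    # Block counts are summed across groups, mirroring the fix above.
    return attention_blocks + mamba_blocks


available = MEM_PER_BLOCK_PER_LAYER * (1024 // BLOCK_SIZE + 1 + GAMMA)
# 1024 tokens fit exactly (64 + 1 + 2 = 67 blocks); 1025 tokens would need a
# 65th attention block, so auto-fit should settle on max_model_len == 1024.
assert blocks_needed(1024) * MEM_PER_BLOCK_PER_LAYER == available
assert blocks_needed(1025) * MEM_PER_BLOCK_PER_LAYER > available
```

Under the old code, only the first group's spec would have contributed to `blocks_needed`, so the budget for a hybrid model would be computed as if every layer were the same kind, and the auto-fitted `max_model_len` could overshoot the memory actually available.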