[Feature] Estimate max-model-len use available KV cache memory (#16168)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
@@ -3,14 +3,16 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.utils import GiB_bytes, sha256
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
PrefixCachingMetrics,
|
||||
estimate_max_model_len,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens,
|
||||
hash_request_tokens,
|
||||
@@ -426,3 +428,45 @@ def test_unify_kv_cache_configs():
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
unify_kv_cache_configs(diff_kv_cache_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "max_model_len", "want_estimated_max_len"), [
|
||||
("Qwen/Qwen1.5-7B", 16385, 16384),
|
||||
("Qwen/Qwen1.5-7B", 16383, 16383),
|
||||
])
|
||||
def test_estimate_max_model_len(model_id, max_model_len,
|
||||
want_estimated_max_len):
|
||||
# Create a VllmConfig
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="generate",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
scheduler_config=scheduler_config,
|
||||
)
|
||||
|
||||
# Create KV cache specs
|
||||
kv_cache_spec = {}
|
||||
for i in range(32):
|
||||
layer_name = f"layer_{i}"
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
# Estimate the maximum model length, 16384 model_len need 8GB
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
8 * GiB_bytes)
|
||||
assert estimated_max_len == want_estimated_max_len
|
||||
|
||||
Reference in New Issue
Block a user