[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919)
Signed-off-by: Chenyaaang <chenyangli@google.com>
@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
     assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
     assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
+    vllm_config = get_vllm_config()
+    vllm_config.model_config.max_model_len = 32000
+    vllm_config.scheduler_config.max_num_seqs = 1200
+    model_runner = get_model_runner(vllm_config)
+
+    # Verify the model runner adjusts num_reqs to avoid SMEM OOM.
+    assert model_runner.num_reqs_most_model_len == 1200
+    # num_page_per_req = 32000 // 128 = 250
+    # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
+    assert model_runner.num_reqs_max_model_len == 524
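For readers tracing the asserted cap, below is a minimal sketch of the SMEM-budget arithmetic implied by the test's comments. The constant names (PAGE_SIZE_TOKENS, SMEM_BYTES, BYTES_PER_PAGE_INDEX) and the half-of-SMEM split are assumptions reverse-engineered from the numbers in the test, not identifiers from the vLLM source.

    # Sketch of the SMEM-budget arithmetic implied by the test's comments.
    # All names here are illustrative assumptions; only the numbers
    # (128, 1024 ** 2, 2, 4) come from the test above.

    PAGE_SIZE_TOKENS = 128       # tokens per KV-cache page (assumed)
    SMEM_BYTES = 1024 ** 2       # 1 MiB of TPU scalar memory (assumed)
    BYTES_PER_PAGE_INDEX = 4     # int32 page-table entries (assumed)

    def max_reqs_for_model_len(model_len: int) -> int:
        """Largest batch whose page tables fit in half of SMEM (assumed split)."""
        num_pages_per_req = model_len // PAGE_SIZE_TOKENS
        return SMEM_BYTES // 2 // num_pages_per_req // BYTES_PER_PAGE_INDEX

    # Reproduces the assertions in test_most_model_len:
    assert max_reqs_for_model_len(32000) == 524             # SMEM-bound cap
    assert min(max_reqs_for_model_len(2048), 1200) == 1200  # config bound wins

Under this reading, requests no longer than VLLM_TPU_MOST_MODEL_LEN keep the full configured batch (max_num_seqs = 1200), while requests approaching max_model_len fall into a smaller batch whose size is bounded by the page-table footprint in SMEM.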