[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919)

Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
Chenyaaang
2025-06-25 15:51:02 -07:00
committed by GitHub
parent 55c65ab495
commit 2d7620c3eb
5 changed files with 184 additions and 76 deletions

View File

@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
vllm_config = get_vllm_config()
vllm_config.model_config.max_model_len = 32000
vllm_config.scheduler_config.max_num_seqs = 1200
model_runner = get_model_runner(vllm_config)
# verify model runner will adjust num_reqs to avoid SMEM OOM.
assert model_runner.num_reqs_most_model_len == 1200
# num_page_per_req = 32k // 128
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
assert model_runner.num_reqs_max_model_len == 524