From d1a6e96d9e0b76cc2a0af33e014b4bd8b860f1e4 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 2 Mar 2026 14:27:06 -0500 Subject: [PATCH] [torch.compile] Improve cold and warm start compile tests (#35709) Signed-off-by: Richard Zou --- tests/compile/test_cold_start.py | 48 ----------------- tests/compile/test_startup.py | 71 ++++++++++++++++++++++++++ tests/conftest.py | 8 +++ vllm/compilation/compiler_interface.py | 1 + vllm/compilation/counter.py | 2 + 5 files changed, 82 insertions(+), 48 deletions(-) delete mode 100644 tests/compile/test_cold_start.py create mode 100644 tests/compile/test_startup.py diff --git a/tests/compile/test_cold_start.py b/tests/compile/test_cold_start.py deleted file mode 100644 index 5482b4c9a..000000000 --- a/tests/compile/test_cold_start.py +++ /dev/null @@ -1,48 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from torch._dynamo.utils import counters - -from vllm import LLM -from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode - - -def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache): - # Run in same process so we can access PyTorch's internal counters - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - - # I'm not sure if this is going to affect the numbers - monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0") - - # Force cold compilation - monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - - compilation_config = CompilationConfig( - mode=CompilationMode.VLLM_COMPILE, - cudagraph_mode=CUDAGraphMode.NONE, # make the model loading faster - ) - - counters.clear() - - _ = LLM( - model="microsoft/Phi-tiny-MoE-instruct", - max_model_len=256, - load_format="dummy", # make the model loading faster - compilation_config=compilation_config, - num_gpu_blocks_override=8, # make the model loading faster - ) - - # vLLM-compile cold start is special. By default, we do - # one full dynamo capture of the entire forward pass. - # The forward pass consists of 32 transformer layers. - # Then, we split on the attention operation. This results in - # 33 subgraphs (not including the attention operation). - # We then generate compiled artifacts for the unique subgraphs. - # - # There are actually only 3 unique subgraphs for this model - # (all of its transformer layers are the same modulo weights); - # this is true for most vLLM models. - # So we test that during cold start, we are only compling - # for 3 unique subgraphs. - assert counters["aot_autograd"]["autograd_cache_miss"] == 3 - assert counters["aot_autograd"]["autograd_cache_hit"] == 0 diff --git a/tests/compile/test_startup.py b/tests/compile/test_startup.py new file mode 100644 index 000000000..acdce9d0b --- /dev/null +++ b/tests/compile/test_startup.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Cold start and warm start tests for vLLM-compile. + +Cold start runs in a forked child (must fork before CUDA init) which +populates on-disk caches and asserts cold-start counters. Warm start +then runs in the parent with clean in-memory state but populated caches. +""" + +import multiprocessing as mp + +from torch._dynamo.utils import counters + +from vllm.compilation.counter import compilation_counter +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode + +MODEL = "microsoft/Phi-tiny-MoE-instruct" + + +def _run_vllm(vllm_runner): + with vllm_runner( + MODEL, + trust_remote_code=False, + max_model_len=256, + max_num_batched_tokens=1024, + load_format="dummy", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + cudagraph_mode=CUDAGraphMode.NONE, + ), + num_gpu_blocks_override=8, + ): + pass + + +def _cold_start(vllm_runner): + counters.clear() + with compilation_counter.expect( + num_compiled_artifacts_saved=3, + num_compiled_artifacts_loaded=0, + ): + _run_vllm(vllm_runner) + assert counters["aot_autograd"]["total"] == 33 + assert counters["aot_autograd"]["autograd_cache_miss"] == 3 + assert counters["aot_autograd"]["autograd_cache_hit"] == 0 + + +def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache): + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # Cold start in a forked child (must fork before CUDA init). + # This model has 32 identical transformer layers which produce + # 33 subgraphs after splitting on attention — only 3 are unique. + ctx = mp.get_context("fork") + p = ctx.Process(target=_cold_start, args=(vllm_runner,)) + p.start() + p.join() + assert p.exitcode == 0, "Cold-start child failed" + + # Warm start — compiled artifacts loaded from disk cache. + counters.clear() + with compilation_counter.expect( + num_compiled_artifacts_loaded=3, + # TODO: warm start should not save any artifacts + # https://github.com/vllm-project/vllm/issues/35708 + num_compiled_artifacts_saved=1, + ): + _run_vllm(vllm_runner) + assert counters["aot_autograd"]["total"] == 30 + assert counters["aot_autograd"]["autograd_cache_miss"] == 0 + assert counters["aot_autograd"]["autograd_cache_hit"] == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 5a2beea89..164cbeee2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1548,6 +1548,14 @@ def use_fresh_inductor_cache(): yield +@pytest.fixture +def fresh_vllm_cache(monkeypatch, use_fresh_inductor_cache): + """Temporary VLLM_CACHE_ROOT combined with a fresh inductor cache.""" + with tempfile.TemporaryDirectory() as tmp_dir: + monkeypatch.setenv("VLLM_CACHE_ROOT", tmp_dir) + yield tmp_dir + + @pytest.fixture(scope="function") def enable_pickle(monkeypatch): """`LLM.apply_model` requires pickling a function.""" diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index e021ce9e3..e7748e380 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -368,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface): inductor_compiled_graph = torch._inductor.CompiledArtifact.load( path=path, format=self.save_format ) + compilation_counter.num_compiled_artifacts_loaded += 1 from torch._inductor.compile_fx import graph_returns_tuple returns_tuple = graph_returns_tuple(graph) diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 29d3045aa..2ed49b9e3 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -29,6 +29,8 @@ class CompilationCounter: num_cache_entries_updated: int = 0 # The number of standalone_compile compiled artifacts saved num_compiled_artifacts_saved: int = 0 + # The number of standalone_compile compiled artifacts loaded from cache + num_compiled_artifacts_loaded: int = 0 # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE stock_torch_compile_count: int = 0