diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py
index 4772ef4c9..9f6a1a13e 100644
--- a/tests/compile/test_aot_compile.py
+++ b/tests/compile/test_aot_compile.py
@@ -4,6 +4,7 @@
 import functools
 import hashlib
 import multiprocessing
+import os
 import pickle
 import tempfile
 from contextlib import contextmanager
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
     StandaloneCompiledArtifacts,
     VllmSerializableFunction,
 )
+from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CompilationConfig,
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
         assert isinstance(config, dict)
         assert "bundled_autograd_cache" in config
         assert config["bundled_autograd_cache"] is True
+
+
+@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+def test_disable_compile_cache_skips_aot_save(
+    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
+):
+    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be saved."""
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
+    disable_envs_cache()
+
+    args = (torch.randn(10, 10),)
+    expected = reference_fn(*args)
+    vllm_config = make_vllm_config()
+
+    with (
+        use_vllm_config(vllm_config),
+        compilation_counter.expect(
+            num_aot_compiles=1,
+            num_aot_artifacts_saved=0,
+            num_aot_artifacts_loaded=0,
+        ),
+    ):
+        mod = CompiledMod(vllm_config=vllm_config)
+        actual = mod(*args)
+
+    assert torch.allclose(actual, expected)
+
+    # No cached artifact should exist on disk
+    aot_dir = os.path.join(fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile")
+    if os.path.isdir(aot_dir):
+        for root, _dirs, files in os.walk(aot_dir):
+            for f in files:
+                assert f != "model", (
+                    f"AOT artifact unexpectedly saved at {os.path.join(root, f)}"
+                )
+
+
+@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+def test_disable_compile_cache_skips_aot_load(
+    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
+):
+    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be loaded."""
+    # Phase 1: compile and save with cache enabled
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
+    disable_envs_cache()
+
+    args = (torch.randn(10, 10),)
+    vllm_config = make_vllm_config()
+
+    with (
+        use_vllm_config(vllm_config),
+        compilation_counter.expect(num_aot_artifacts_saved=1),
+    ):
+        CompiledMod(vllm_config=vllm_config)(*args)
+
+    # Phase 2: disable cache, compile again; should NOT load from disk
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+    disable_envs_cache()
+    torch._dynamo.reset()
+
+    vllm_config = make_vllm_config()
+    with (
+        use_vllm_config(vllm_config),
+        compilation_counter.expect(
+            num_aot_compiles=1,
+            num_aot_artifacts_saved=0,
+            num_aot_artifacts_loaded=0,
+        ),
+    ):
+        mod = CompiledMod(vllm_config=vllm_config)
+        mod(*args)
+
+    assert not mod.was_aot_compile_fn_loaded_from_disk
+
+
+@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+def test_aot_counters_on_save_and_load(
+    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
+):
+    """Verify AOT counters are incremented correctly on save and load."""
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
+    disable_envs_cache()
+
+    args = (torch.randn(10, 10),)
+
+    # Phase 1: fresh compile + save
+    vllm_config = make_vllm_config()
+    with (
+        use_vllm_config(vllm_config),
+        compilation_counter.expect(
+            num_aot_compiles=1,
+            num_aot_artifacts_saved=1,
+            num_aot_artifacts_loaded=0,
+        ),
+    ):
+        CompiledMod(vllm_config=vllm_config)(*args)
+
+    # Phase 2: load from cache
+    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
+    disable_envs_cache()
+
+    vllm_config = make_vllm_config()
+    with (
+        use_vllm_config(vllm_config),
+        compilation_counter.expect(
+            num_aot_compiles=0,
+            num_aot_artifacts_saved=0,
+            num_aot_artifacts_loaded=1,
+        ),
+    ):
+        CompiledMod(vllm_config=vllm_config)(*args)
diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py
index 2ed49b9e3..fd62e558d 100644
--- a/vllm/compilation/counter.py
+++ b/vllm/compilation/counter.py
@@ -31,6 +31,12 @@ class CompilationCounter:
     num_compiled_artifacts_saved: int = 0
     # The number of standalone_compile compiled artifacts loaded from cache
     num_compiled_artifacts_loaded: int = 0
+    # The number of AOT compile invocations
+    num_aot_compiles: int = 0
+    # The number of AOT compiled artifacts saved to disk
+    num_aot_artifacts_saved: int = 0
+    # The number of AOT compiled artifacts loaded from disk
+    num_aot_artifacts_loaded: int = 0
     # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
     stock_torch_compile_count: int = 0
 
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index d52d45708..da32bef73 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -266,6 +266,51 @@ def _verify_source_unchanged(
     )
 
 
+def _try_load_aot_compiled_fn(
+    model: Any,
+    aot_compilation_path: str,
+) -> Any | None:
+    """Try to load an AOT-compiled function from disk.
+
+    Returns the loaded callable on success, or None on failure.
+    Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set.
+    """
+    try:
+        with monitor_torch_compile(model.vllm_config):
+            with (
+                set_current_vllm_config(model.vllm_config),
+                open(aot_compilation_path, "rb") as f,
+            ):
+                loaded_fn = torch.compiler.load_compiled_function(
+                    f, f_globals=model.forward.__globals__
+                )
+            _verify_source_unchanged(loaded_fn.source_info(), model.vllm_config)
+            ds_config = model.compilation_config.dynamic_shapes_config
+            if not ds_config.evaluate_guards:
+                loaded_fn.disable_guard_check()
+            # Eagerly load compiled artifacts now that traced_files
+            # is populated by _verify_source_unchanged.
+            with maybe_use_cudagraph_partition_wrapper(model.vllm_config):
+                loaded_fn._artifacts.compiled_fn.finalize_loading(model.vllm_config)
+        compilation_counter.num_aot_artifacts_loaded += 1
+        logger.info("Directly load AOT compilation from path %s", aot_compilation_path)
+        return loaded_fn
+    except Exception as e:
+        if os.path.exists(aot_compilation_path):
+            if isinstance(e, EOFError):
+                message = "Compile cache file corrupted."
+            else:
+                message = str(e)
+            logger.warning(
+                "Compiling model again due to a load failure from %s, reason: %s",
+                aot_compilation_path,
+                message,
+            )
+        if envs.VLLM_FORCE_AOT_LOAD:
+            raise e
+        return None
+
+
 def _support_torch_compile(
     cls: type[_T],
     dynamic_arg_dims: dict[str, int | list[int]],
@@ -438,51 +483,17 @@ def _support_torch_compile(
             dp_rank = self.vllm_config.parallel_config.data_parallel_index
             cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
             aot_compilation_path = os.path.join(cache_dir, "model")
-            try:
-                with monitor_torch_compile(self.vllm_config):
+            if not envs.VLLM_DISABLE_COMPILE_CACHE:
+                loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
+                if loaded_fn is not None:
+                    self.aot_compiled_fn = loaded_fn
+                    self.was_aot_compile_fn_loaded_from_disk = True
                     with (
-                        set_current_vllm_config(self.vllm_config),
-                        open(aot_compilation_path, "rb") as f,
+                        monitor_profiling_run(),
+                        maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                     ):
-                        loaded_fn = torch.compiler.load_compiled_function(
-                            f, f_globals=self.forward.__globals__
-                        )
-                    _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
-                    ds_config = self.compilation_config.dynamic_shapes_config
-                    if not ds_config.evaluate_guards:
-                        loaded_fn.disable_guard_check()
-                    # Eagerly load compiled artifacts now that traced_files
-                    # is populated by _verify_source_unchanged.
-                    with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
-                        loaded_fn._artifacts.compiled_fn.finalize_loading(
-                            self.vllm_config
-                        )
-                    self.aot_compiled_fn = loaded_fn
-                    self.was_aot_compile_fn_loaded_from_disk = True
-            except Exception as e:
-                if os.path.exists(aot_compilation_path):
-                    if isinstance(e, EOFError):
-                        message = "Compile cache file corrupted."
-                    else:
-                        message = str(e)
-                    logger.warning(
-                        "Compiling model again due to a load failure from %s, "
-                        "reason: %s",
-                        aot_compilation_path,
-                        message,
-                    )
-                if envs.VLLM_FORCE_AOT_LOAD:
-                    raise e
-            if getattr(self, "aot_compiled_fn", None) is not None:
-                logger.info(
-                    "Directly load AOT compilation from path %s", aot_compilation_path
-                )
-                with (
-                    monitor_profiling_run(),
-                    maybe_use_cudagraph_partition_wrapper(self.vllm_config),
-                ):
-                    output = self.aot_compiled_fn(self, *args, **kwargs)
-                    return output
+                        output = self.aot_compiled_fn(self, *args, **kwargs)
+                    return output
 
         if self.compiled:
             assert (
@@ -570,6 +581,7 @@ def _support_torch_compile(
                 self._aot_cache_dir = cache_dir
                 with monitor_torch_compile(self.vllm_config):
                     self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
+                compilation_counter.num_aot_compiles += 1
                 # All compilation is done at this point, save the
                 # AOT artifact.
                 self.save_aot_compiled_function()
@@ -593,6 +605,9 @@ def _support_torch_compile(
 
     # triggers VllmSerializableFunction.serialize()
     def save_aot_compiled_function(self: type[_T]) -> None:
+        if envs.VLLM_DISABLE_COMPILE_CACHE:
+            return
+
         if self.was_aot_compile_fn_loaded_from_disk:
             logger.debug("AOT compiled function was loaded from cache, skipping save")
             return
@@ -608,6 +623,7 @@ def _support_torch_compile(
         tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
         self.aot_compiled_fn.save_compiled_function(tmp_file)
         os.replace(tmp_file, self._aot_compilation_path)
+        compilation_counter.num_aot_artifacts_saved += 1
         logger.info_once(
             "saved AOT compiled function to %s",
             self._aot_compilation_path,
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 5dff296d0..c6f6072bd 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
     compilation_counter.num_cache_entries_updated = 0
     compilation_counter.num_compiled_artifacts_saved = 0
     compilation_counter.stock_torch_compile_count = 0
+    compilation_counter.num_aot_compiles = 0
+    compilation_counter.num_aot_artifacts_saved = 0
+    compilation_counter.num_aot_artifacts_loaded = 0
     # Clear the AOT compiled function so the model is forced to
     # recompile on the next call. Without this, decorators.py