[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-03-11 06:12:03 -04:00
committed by GitHub
parent c87fb515ed
commit 09b6f99852
4 changed files with 182 additions and 43 deletions

View File

@@ -4,6 +4,7 @@
import functools
import hashlib
import multiprocessing
import os
import pickle
import tempfile
from contextlib import contextmanager
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
StandaloneCompiledArtifacts,
VllmSerializableFunction,
)
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
assert isinstance(config, dict)
assert "bundled_autograd_cache" in config
assert config["bundled_autograd_cache"] is True
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_save(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be saved."""
    # Enable AOT compilation while disabling the compile cache entirely.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    inputs = (torch.randn(10, 10),)
    baseline = reference_fn(*inputs)
    vllm_config = make_vllm_config()

    # A compile must happen, but no artifact may be saved or loaded.
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        compiled = CompiledMod(vllm_config=vllm_config)
        result = compiled(*inputs)
    assert torch.allclose(result, baseline)

    # The on-disk cache must hold no saved "model" artifact either.
    aot_dir = os.path.join(fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile")
    if os.path.isdir(aot_dir):
        for root, _dirs, filenames in os.walk(aot_dir):
            for name in filenames:
                assert name != "model", (
                    f"AOT artifact unexpectedly saved at {os.path.join(root, name)}"
                )
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be loaded."""
    # Phase 1: compile and save with cache enabled
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    sample_inputs = (torch.randn(10, 10),)
    vllm_config = make_vllm_config()
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(num_aot_artifacts_saved=1),
    ):
        CompiledMod(vllm_config=vllm_config)(*sample_inputs)

    # Phase 2: disable cache, compile again — should NOT load from disk
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    disable_envs_cache()
    torch._dynamo.reset()

    vllm_config = make_vllm_config()
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        model = CompiledMod(vllm_config=vllm_config)
        model(*sample_inputs)
    # The module itself must also report that nothing came off disk.
    assert not model.was_aot_compile_fn_loaded_from_disk
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_aot_counters_on_save_and_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """Verify AOT counters are incremented correctly on save and load."""
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()
    sample_inputs = (torch.randn(10, 10),)

    # Phase 1: fresh compile + save
    first_config = make_vllm_config()
    with (
        use_vllm_config(first_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=1,
            num_aot_artifacts_loaded=0,
        ),
    ):
        CompiledMod(vllm_config=first_config)(*sample_inputs)

    # Phase 2: load from cache
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    disable_envs_cache()
    second_config = make_vllm_config()
    with (
        use_vllm_config(second_config),
        compilation_counter.expect(
            num_aot_compiles=0,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=1,
        ),
    ):
        CompiledMod(vllm_config=second_config)(*sample_inputs)

View File

@@ -31,6 +31,12 @@ class CompilationCounter:
num_compiled_artifacts_saved: int = 0
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded: int = 0
# The number of AOT compile invocations
num_aot_compiles: int = 0
# The number of AOT compiled artifacts saved to disk
num_aot_artifacts_saved: int = 0
# The number of AOT compiled artifacts loaded from disk
num_aot_artifacts_loaded: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0

View File

@@ -266,6 +266,51 @@ def _verify_source_unchanged(
)
def _try_load_aot_compiled_fn(
    model: Any,
    aot_compilation_path: str,
) -> Any | None:
    """Try to load an AOT-compiled function from disk.

    Returns the loaded callable on success, or None on failure.
    Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set.

    Args:
        model: The module whose compiled forward is being restored; its
            ``vllm_config`` and ``compilation_config`` drive the load.
        aot_compilation_path: Filesystem path of the serialized artifact.
    """
    try:
        with monitor_torch_compile(model.vllm_config):
            with (
                set_current_vllm_config(model.vllm_config),
                open(aot_compilation_path, "rb") as f,
            ):
                # Deserialize against the module's own globals so the
                # restored function resolves the same names as forward().
                loaded_fn = torch.compiler.load_compiled_function(
                    f, f_globals=model.forward.__globals__
                )
                # Reject stale artifacts whose traced sources have changed.
                _verify_source_unchanged(loaded_fn.source_info(), model.vllm_config)
                ds_config = model.compilation_config.dynamic_shapes_config
                if not ds_config.evaluate_guards:
                    loaded_fn.disable_guard_check()
                # Eagerly load compiled artifacts now that traced_files
                # is populated by _verify_source_unchanged.
                with maybe_use_cudagraph_partition_wrapper(model.vllm_config):
                    loaded_fn._artifacts.compiled_fn.finalize_loading(model.vllm_config)
        compilation_counter.num_aot_artifacts_loaded += 1
        logger.info("Directly load AOT compilation from path %s", aot_compilation_path)
        return loaded_fn
    except Exception as e:
        # Only warn when an artifact actually exists: a missing file is the
        # normal cold-cache case and should stay silent.
        if os.path.exists(aot_compilation_path):
            if isinstance(e, EOFError):
                message = "Compile cache file corrupted."
            else:
                message = str(e)
            logger.warning(
                "Compiling model again due to a load failure from %s, reason: %s",
                aot_compilation_path,
                message,
            )
        # Under VLLM_FORCE_AOT_LOAD a failed load is a hard error, not a
        # fall-through to recompilation.
        if envs.VLLM_FORCE_AOT_LOAD:
            raise e
    return None
def _support_torch_compile(
cls: type[_T],
dynamic_arg_dims: dict[str, int | list[int]],
@@ -438,51 +483,17 @@ def _support_torch_compile(
dp_rank = self.vllm_config.parallel_config.data_parallel_index
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model")
try:
with monitor_torch_compile(self.vllm_config):
if not envs.VLLM_DISABLE_COMPILE_CACHE:
loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
if loaded_fn is not None:
self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True
with (
set_current_vllm_config(self.vllm_config),
open(aot_compilation_path, "rb") as f,
monitor_profiling_run(),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
):
loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=self.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
ds_config = self.compilation_config.dynamic_shapes_config
if not ds_config.evaluate_guards:
loaded_fn.disable_guard_check()
# Eagerly load compiled artifacts now that traced_files
# is populated by _verify_source_unchanged.
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
loaded_fn._artifacts.compiled_fn.finalize_loading(
self.vllm_config
)
self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e:
if os.path.exists(aot_compilation_path):
if isinstance(e, EOFError):
message = "Compile cache file corrupted."
else:
message = str(e)
logger.warning(
"Compiling model again due to a load failure from %s, "
"reason: %s",
aot_compilation_path,
message,
)
if envs.VLLM_FORCE_AOT_LOAD:
raise e
if getattr(self, "aot_compiled_fn", None) is not None:
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
with (
monitor_profiling_run(),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
):
output = self.aot_compiled_fn(self, *args, **kwargs)
return output
output = self.aot_compiled_fn(self, *args, **kwargs)
return output
if self.compiled:
assert (
@@ -570,6 +581,7 @@ def _support_torch_compile(
self._aot_cache_dir = cache_dir
with monitor_torch_compile(self.vllm_config):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
compilation_counter.num_aot_compiles += 1
# All compilation is done at this point, save the
# AOT artifact.
self.save_aot_compiled_function()
@@ -593,6 +605,9 @@ def _support_torch_compile(
# triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self: type[_T]) -> None:
if envs.VLLM_DISABLE_COMPILE_CACHE:
return
if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save")
return
@@ -608,6 +623,7 @@ def _support_torch_compile(
tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
self.aot_compiled_fn.save_compiled_function(tmp_file)
os.replace(tmp_file, self._aot_compilation_path)
compilation_counter.num_aot_artifacts_saved += 1
logger.info_once(
"saved AOT compiled function to %s",
self._aot_compilation_path,

View File

@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
compilation_counter.num_cache_entries_updated = 0
compilation_counter.num_compiled_artifacts_saved = 0
compilation_counter.stock_torch_compile_count = 0
compilation_counter.num_aot_compiles = 0
compilation_counter.num_aot_artifacts_saved = 0
compilation_counter.num_aot_artifacts_loaded = 0
# Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py