[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-03-11 06:12:03 -04:00
committed by GitHub
parent c87fb515ed
commit 09b6f99852
4 changed files with 182 additions and 43 deletions

View File

@@ -4,6 +4,7 @@
import functools
import hashlib
import multiprocessing
import os
import pickle
import tempfile
from contextlib import contextmanager
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
StandaloneCompiledArtifacts,
VllmSerializableFunction,
)
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
assert isinstance(config, dict)
assert "bundled_autograd_cache" in config
assert config["bundled_autograd_cache"] is True
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_save(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be saved."""
    # Turn the cache off before anything compiles; the cached env snapshot
    # must be invalidated so the new values are actually observed.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    sample_input = (torch.randn(10, 10),)
    expected = reference_fn(*sample_input)

    vllm_config = make_vllm_config()
    # Exactly one AOT compile must occur, with zero save/load activity.
    with (
        use_vllm_config(vllm_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        compiled = CompiledMod(vllm_config=vllm_config)
        actual = compiled(*sample_input)
    assert torch.allclose(actual, expected)

    # Belt-and-braces filesystem check: no "model" artifact file may exist
    # anywhere under the AOT cache directory.
    aot_dir = os.path.join(fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile")
    if os.path.isdir(aot_dir):
        for dirpath, _subdirs, filenames in os.walk(aot_dir):
            for fname in filenames:
                assert fname != "model", (
                    f"AOT artifact unexpectedly saved at {os.path.join(dirpath, fname)}"
                )
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be loaded."""
    # Phase 1: compile with the cache enabled so an artifact lands on disk.
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()
    sample_input = (torch.randn(10, 10),)

    first_config = make_vllm_config()
    with (
        use_vllm_config(first_config),
        compilation_counter.expect(num_aot_artifacts_saved=1),
    ):
        CompiledMod(vllm_config=first_config)(*sample_input)

    # Phase 2: disable the cache and recompile from scratch. Dynamo state is
    # reset so the second run cannot reuse the in-memory compilation either;
    # the saved artifact must NOT be read back.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    disable_envs_cache()
    torch._dynamo.reset()

    second_config = make_vllm_config()
    with (
        use_vllm_config(second_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        compiled = CompiledMod(vllm_config=second_config)
        compiled(*sample_input)
    assert not compiled.was_aot_compile_fn_loaded_from_disk
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_aot_counters_on_save_and_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """Verify AOT counters are incremented correctly on save and load."""
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()
    sample_input = (torch.randn(10, 10),)

    # Phase 1: fresh compile + save — one compile, one artifact written.
    save_config = make_vllm_config()
    with (
        use_vllm_config(save_config),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=1,
            num_aot_artifacts_loaded=0,
        ),
    ):
        CompiledMod(vllm_config=save_config)(*sample_input)

    # Phase 2: load from cache — no new compile, one artifact read back.
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    disable_envs_cache()
    load_config = make_vllm_config()
    with (
        use_vllm_config(load_config),
        compilation_counter.expect(
            num_aot_compiles=0,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=1,
        ),
    ):
        CompiledMod(vllm_config=load_config)(*sample_input)