[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
|
||||
StandaloneCompiledArtifacts,
|
||||
VllmSerializableFunction,
|
||||
)
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
|
||||
assert isinstance(config, dict)
|
||||
assert "bundled_autograd_cache" in config
|
||||
assert config["bundled_autograd_cache"] is True
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
def test_disable_compile_cache_skips_aot_save(
|
||||
monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
|
||||
):
|
||||
"""When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be saved."""
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
disable_envs_cache()
|
||||
|
||||
args = (torch.randn(10, 10),)
|
||||
expected = reference_fn(*args)
|
||||
vllm_config = make_vllm_config()
|
||||
|
||||
with (
|
||||
use_vllm_config(vllm_config),
|
||||
compilation_counter.expect(
|
||||
num_aot_compiles=1,
|
||||
num_aot_artifacts_saved=0,
|
||||
num_aot_artifacts_loaded=0,
|
||||
),
|
||||
):
|
||||
mod = CompiledMod(vllm_config=vllm_config)
|
||||
actual = mod(*args)
|
||||
|
||||
assert torch.allclose(actual, expected)
|
||||
|
||||
# No cached artifact should exist on disk
|
||||
aot_dir = os.path.join(fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile")
|
||||
if os.path.isdir(aot_dir):
|
||||
for root, _dirs, files in os.walk(aot_dir):
|
||||
for f in files:
|
||||
assert f != "model", (
|
||||
f"AOT artifact unexpectedly saved at {os.path.join(root, f)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
def test_disable_compile_cache_skips_aot_load(
|
||||
monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
|
||||
):
|
||||
"""When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be loaded."""
|
||||
# Phase 1: compile and save with cache enabled
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
disable_envs_cache()
|
||||
|
||||
args = (torch.randn(10, 10),)
|
||||
vllm_config = make_vllm_config()
|
||||
|
||||
with (
|
||||
use_vllm_config(vllm_config),
|
||||
compilation_counter.expect(num_aot_artifacts_saved=1),
|
||||
):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
|
||||
# Phase 2: disable cache, compile again — should NOT load from disk
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
disable_envs_cache()
|
||||
torch._dynamo.reset()
|
||||
|
||||
vllm_config = make_vllm_config()
|
||||
with (
|
||||
use_vllm_config(vllm_config),
|
||||
compilation_counter.expect(
|
||||
num_aot_compiles=1,
|
||||
num_aot_artifacts_saved=0,
|
||||
num_aot_artifacts_loaded=0,
|
||||
),
|
||||
):
|
||||
mod = CompiledMod(vllm_config=vllm_config)
|
||||
mod(*args)
|
||||
|
||||
assert not mod.was_aot_compile_fn_loaded_from_disk
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
def test_aot_counters_on_save_and_load(
|
||||
monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
|
||||
):
|
||||
"""Verify AOT counters are incremented correctly on save and load."""
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
disable_envs_cache()
|
||||
|
||||
args = (torch.randn(10, 10),)
|
||||
|
||||
# Phase 1: fresh compile + save
|
||||
vllm_config = make_vllm_config()
|
||||
with (
|
||||
use_vllm_config(vllm_config),
|
||||
compilation_counter.expect(
|
||||
num_aot_compiles=1,
|
||||
num_aot_artifacts_saved=1,
|
||||
num_aot_artifacts_loaded=0,
|
||||
),
|
||||
):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
|
||||
# Phase 2: load from cache
|
||||
monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
disable_envs_cache()
|
||||
|
||||
vllm_config = make_vllm_config()
|
||||
with (
|
||||
use_vllm_config(vllm_config),
|
||||
compilation_counter.expect(
|
||||
num_aot_compiles=0,
|
||||
num_aot_artifacts_saved=0,
|
||||
num_aot_artifacts_loaded=1,
|
||||
),
|
||||
):
|
||||
CompiledMod(vllm_config=vllm_config)(*args)
|
||||
|
||||
Reference in New Issue
Block a user