[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
|
||||
StandaloneCompiledArtifacts,
|
||||
VllmSerializableFunction,
|
||||
)
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
|
||||
assert isinstance(config, dict)
|
||||
assert "bundled_autograd_cache" in config
|
||||
assert config["bundled_autograd_cache"] is True
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_save(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """AOT artifacts must not be written to disk when VLLM_DISABLE_COMPILE_CACHE=1."""
    # Turn AOT compilation on but disable the on-disk compile cache.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    inputs = (torch.randn(10, 10),)
    want = reference_fn(*inputs)
    cfg = make_vllm_config()

    # The compile itself must still happen; only persistence is suppressed.
    with (
        use_vllm_config(cfg),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        compiled = CompiledMod(vllm_config=cfg)
        got = compiled(*inputs)

    assert torch.allclose(got, want)

    # The cache tree must not contain any serialized "model" artifact.
    aot_dir = os.path.join(fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile")
    if os.path.isdir(aot_dir):
        for root, _dirs, files in os.walk(aot_dir):
            for name in files:
                assert name != "model", (
                    f"AOT artifact unexpectedly saved at {os.path.join(root, name)}"
                )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """AOT artifacts must not be read back when VLLM_DISABLE_COMPILE_CACHE=1."""
    # Phase 1: with caching enabled, compile once so an artifact lands on disk.
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    inputs = (torch.randn(10, 10),)
    cfg = make_vllm_config()

    with (
        use_vllm_config(cfg),
        compilation_counter.expect(num_aot_artifacts_saved=1),
    ):
        CompiledMod(vllm_config=cfg)(*inputs)

    # Phase 2: disable the cache and compile again — the saved artifact
    # must be ignored, forcing a fresh AOT compile.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    disable_envs_cache()
    torch._dynamo.reset()

    cfg = make_vllm_config()
    with (
        use_vllm_config(cfg),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=0,
        ),
    ):
        compiled = CompiledMod(vllm_config=cfg)
        compiled(*inputs)

    assert not compiled.was_aot_compile_fn_loaded_from_disk
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_aot_counters_on_save_and_load(
    monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
    """The AOT counters track one save on a fresh compile and one load on a cache hit."""
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
    disable_envs_cache()

    inputs = (torch.randn(10, 10),)

    # First run: nothing cached yet, so we compile and save exactly once.
    cfg = make_vllm_config()
    with (
        use_vllm_config(cfg),
        compilation_counter.expect(
            num_aot_compiles=1,
            num_aot_artifacts_saved=1,
            num_aot_artifacts_loaded=0,
        ),
    ):
        CompiledMod(vllm_config=cfg)(*inputs)

    # Second run: force a load from the artifact written above; no new compile.
    monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
    disable_envs_cache()

    cfg = make_vllm_config()
    with (
        use_vllm_config(cfg),
        compilation_counter.expect(
            num_aot_compiles=0,
            num_aot_artifacts_saved=0,
            num_aot_artifacts_loaded=1,
        ),
    ):
        CompiledMod(vllm_config=cfg)(*inputs)
|
||||
@@ -31,6 +31,12 @@ class CompilationCounter:
|
||||
num_compiled_artifacts_saved: int = 0
|
||||
# The number of standalone_compile compiled artifacts loaded from cache
|
||||
num_compiled_artifacts_loaded: int = 0
|
||||
# The number of AOT compile invocations
|
||||
num_aot_compiles: int = 0
|
||||
# The number of AOT compiled artifacts saved to disk
|
||||
num_aot_artifacts_saved: int = 0
|
||||
# The number of AOT compiled artifacts loaded from disk
|
||||
num_aot_artifacts_loaded: int = 0
|
||||
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
|
||||
stock_torch_compile_count: int = 0
|
||||
|
||||
|
||||
@@ -266,6 +266,51 @@ def _verify_source_unchanged(
|
||||
)
|
||||
|
||||
|
||||
def _try_load_aot_compiled_fn(
    model: Any,
    aot_compilation_path: str,
) -> Any | None:
    """Try to load an AOT-compiled function from disk.

    Returns the loaded callable on success, or None on failure.
    Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set.
    """
    try:
        with monitor_torch_compile(model.vllm_config):
            # Deserialize under the model's vLLM config so any config-dependent
            # deserialization hooks see the right settings.
            with (
                set_current_vllm_config(model.vllm_config),
                open(aot_compilation_path, "rb") as f,
            ):
                loaded_fn = torch.compiler.load_compiled_function(
                    f, f_globals=model.forward.__globals__
                )
            # Reject the artifact if the traced source files changed since save.
            _verify_source_unchanged(loaded_fn.source_info(), model.vllm_config)
            ds_config = model.compilation_config.dynamic_shapes_config
            if not ds_config.evaluate_guards:
                # Guard evaluation disabled by config: skip per-call guard checks.
                loaded_fn.disable_guard_check()
            # Eagerly load compiled artifacts now that traced_files
            # is populated by _verify_source_unchanged.
            with maybe_use_cudagraph_partition_wrapper(model.vllm_config):
                loaded_fn._artifacts.compiled_fn.finalize_loading(model.vllm_config)
        compilation_counter.num_aot_artifacts_loaded += 1
        logger.info("Directly load AOT compilation from path %s", aot_compilation_path)
        return loaded_fn
    except Exception as e:
        # Only warn when an artifact actually exists; a missing file simply
        # means there is nothing cached yet.
        if os.path.exists(aot_compilation_path):
            if isinstance(e, EOFError):
                message = "Compile cache file corrupted."
            else:
                message = str(e)
            logger.warning(
                "Compiling model again due to a load failure from %s, reason: %s",
                aot_compilation_path,
                message,
            )
        if envs.VLLM_FORCE_AOT_LOAD:
            # Caller demanded a cache hit; surface the failure instead of
            # silently recompiling.
            raise e
        return None
|
||||
|
||||
|
||||
def _support_torch_compile(
|
||||
cls: type[_T],
|
||||
dynamic_arg_dims: dict[str, int | list[int]],
|
||||
@@ -438,51 +483,17 @@ def _support_torch_compile(
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_index
|
||||
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
||||
aot_compilation_path = os.path.join(cache_dir, "model")
|
||||
try:
|
||||
with monitor_torch_compile(self.vllm_config):
|
||||
if not envs.VLLM_DISABLE_COMPILE_CACHE:
|
||||
loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
|
||||
if loaded_fn is not None:
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
self.was_aot_compile_fn_loaded_from_disk = True
|
||||
with (
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
open(aot_compilation_path, "rb") as f,
|
||||
monitor_profiling_run(),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
):
|
||||
loaded_fn = torch.compiler.load_compiled_function(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
ds_config = self.compilation_config.dynamic_shapes_config
|
||||
if not ds_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
# Eagerly load compiled artifacts now that traced_files
|
||||
# is populated by _verify_source_unchanged.
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
loaded_fn._artifacts.compiled_fn.finalize_loading(
|
||||
self.vllm_config
|
||||
)
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
self.was_aot_compile_fn_loaded_from_disk = True
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
if isinstance(e, EOFError):
|
||||
message = "Compile cache file corrupted."
|
||||
else:
|
||||
message = str(e)
|
||||
logger.warning(
|
||||
"Compiling model again due to a load failure from %s, "
|
||||
"reason: %s",
|
||||
aot_compilation_path,
|
||||
message,
|
||||
)
|
||||
if envs.VLLM_FORCE_AOT_LOAD:
|
||||
raise e
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
with (
|
||||
monitor_profiling_run(),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
):
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
return output
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
return output
|
||||
|
||||
if self.compiled:
|
||||
assert (
|
||||
@@ -570,6 +581,7 @@ def _support_torch_compile(
|
||||
self._aot_cache_dir = cache_dir
|
||||
with monitor_torch_compile(self.vllm_config):
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
compilation_counter.num_aot_compiles += 1
|
||||
# All compilation is done at this point, save the
|
||||
# AOT artifact.
|
||||
self.save_aot_compiled_function()
|
||||
@@ -593,6 +605,9 @@ def _support_torch_compile(
|
||||
|
||||
# triggers VllmSerializableFunction.serialize()
|
||||
def save_aot_compiled_function(self: type[_T]) -> None:
|
||||
if envs.VLLM_DISABLE_COMPILE_CACHE:
|
||||
return
|
||||
|
||||
if self.was_aot_compile_fn_loaded_from_disk:
|
||||
logger.debug("AOT compiled function was loaded from cache, skipping save")
|
||||
return
|
||||
@@ -608,6 +623,7 @@ def _support_torch_compile(
|
||||
tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
|
||||
self.aot_compiled_fn.save_compiled_function(tmp_file)
|
||||
os.replace(tmp_file, self._aot_compilation_path)
|
||||
compilation_counter.num_aot_artifacts_saved += 1
|
||||
logger.info_once(
|
||||
"saved AOT compiled function to %s",
|
||||
self._aot_compilation_path,
|
||||
|
||||
@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
|
||||
compilation_counter.num_cache_entries_updated = 0
|
||||
compilation_counter.num_compiled_artifacts_saved = 0
|
||||
compilation_counter.stock_torch_compile_count = 0
|
||||
compilation_counter.num_aot_compiles = 0
|
||||
compilation_counter.num_aot_artifacts_saved = 0
|
||||
compilation_counter.num_aot_artifacts_loaded = 0
|
||||
|
||||
# Clear the AOT compiled function so the model is forced to
|
||||
# recompile on the next call. Without this, decorators.py
|
||||
|
||||
Reference in New Issue
Block a user