[torch.compile] consider relevant code in compilation cache (#11614)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-08 18:46:43 +08:00
committed by GitHub
parent cfd3219f58
commit f12141170a
4 changed files with 99 additions and 35 deletions

View File

@@ -3,7 +3,6 @@ import copy
import enum
import hashlib
import json
import os
import sys
import warnings
from contextlib import contextmanager
@@ -2778,9 +2777,8 @@ class CompilationConfig(BaseModel):
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
traced_files: Set[str] = PrivateAttr
compilation_time: float = PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache: Any = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
@@ -2818,6 +2816,7 @@ class CompilationConfig(BaseModel):
"compilation_time",
"bs_to_padded_graph_size",
"pass_config",
"traced_files",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)
@@ -2877,6 +2876,7 @@ class CompilationConfig(BaseModel):
self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.traced_files = set()
self.static_forward_context = {}
self.compilation_time = 0.0
@@ -2899,29 +2899,6 @@ class CompilationConfig(BaseModel):
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if not self.cache_dir:
# No cache dir was provided, so generate one based on the known factors
# that affect the compilation. If none of the factors change,
# the cache dir will be the same, so that we can reuse the compiled
# graph.
hash_key = vllm_config.compute_hash()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir
disabled = envs.VLLM_DISABLE_COMPILE_CACHE
from vllm.compilation.backends import InductorHashCache
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
self.cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info(
"Using cache directory: %s for vLLM's torch.compile",
self.cache_dir)
from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)