[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-06-13 15:23:25 +08:00
Committed by: GitHub
Parent: ce688ad46e
Commit: d70bc7c029
6 changed files with 117 additions and 27 deletions

@@ -28,11 +28,22 @@ class CompilerInterface:
     # This is a class-level attribute.
     name: str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         """
         when the vLLM process uses `cache_dir` as the cache directory,
         the compiler should initialize itself with the cache directory,
         e.g. by re-directing its own cache directory to a sub-directory.
+
+        `prefix` can be used in combination with `cache_dir` to figure out
+        the base cache directory, e.g. when multiple parts of the model are
+        compiled but should share the same cache directory:
+
+            cache_dir = "/path/to/dir/backbone", prefix = "backbone"
+            cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
         """
         pass
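
The directory contract described in the docstring can be checked in isolation. Below is a minimal sketch (the paths and prefixes are illustrative, not part of vLLM's API) showing that stripping the prefix from the end of each part's cache_dir recovers the same shared base, modulo a trailing separator:

    import os

    base = "/path/to/dir"  # hypothetical shared base cache directory
    for prefix in ("backbone", "eagle_head"):
        cache_dir = os.path.join(base, prefix)  # per-part sub-directory
        # strip the prefix to recover the base; note the trailing separator
        recovered = cache_dir[:-len(prefix)] if prefix else cache_dir
        assert recovered == base + os.sep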
@@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
            usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
 
     def compile(
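
`InductorStandaloneAdaptor` accepts the new `prefix` argument but only stores `cache_dir`; keeping the signature identical across adaptors lets callers initialize any compiler backend uniformly. A toy sketch of that pattern (stub classes for illustration, not vLLM's real ones):

    class StandaloneLike:
        def initialize_cache(self, cache_dir, disable_cache=False, prefix=""):
            self.cache_dir = cache_dir  # prefix accepted but unused

    class SharedCacheLike:
        def initialize_cache(self, cache_dir, disable_cache=False, prefix=""):
            self.cache_dir = cache_dir
            self.base_cache_dir = (cache_dir[:-len(prefix)]
                                   if prefix else cache_dir)

    # the caller does not need to know which adaptor uses the prefix
    for adaptor in (StandaloneLike(), SharedCacheLike()):
        adaptor.initialize_cache("/tmp/vllm_cache/backbone", prefix="backbone")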
@@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
            usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
+        self.prefix = prefix
+        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
         if disable_cache:
             return
         # redirect the cache directory to a sub-directory
         # set flags so that Inductor and Triton store their cache
         # in the cache_dir, then users only need to copy the cache_dir
         # to another machine to reuse the cache.
-        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
         os.makedirs(inductor_cache, exist_ok=True)
         os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
-        triton_cache = os.path.join(cache_dir, "triton_cache")
+        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
         os.makedirs(triton_cache, exist_ok=True)
         os.environ["TRITON_CACHE_DIR"] = triton_cache
@@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
             nonlocal file_path
             compiled_fn = inductor_compiled_graph.current_callable
             file_path = compiled_fn.__code__.co_filename  # noqa
-            if not file_path.startswith(self.cache_dir):
+            if not file_path.startswith(self.base_cache_dir):
                 # hooked in the align_inputs_from_check_idxs function
                 # in torch/_inductor/utils.py
                 for cell in compiled_fn.__closure__:
                     if not callable(cell.cell_contents):
                         continue
                     if cell.cell_contents.__code__.co_filename.startswith(
-                            self.cache_dir):
+                            self.base_cache_dir):
                         # this is the real file path compiled from Inductor
                         file_path = cell.cell_contents.__code__.co_filename
                         break
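
The `startswith` check now compares against the shared base directory, since Inductor output for any prefix lives under it. The closure walk itself is standard Python introspection: when Inductor wraps the compiled callable (e.g. in `align_inputs_from_check_idxs`), the wrapper's `__code__.co_filename` points at torch's own source, so the real compiled file is recovered from the wrapped function held in a closure cell. A self-contained toy of the mechanism (here both filenames are the same script, but the lookup path is identical):

    def make_wrapper(inner):
        def wrapper(*args, **kwargs):
            return inner(*args, **kwargs)  # `inner` lives in a closure cell
        return wrapper

    def real_fn():
        return 42

    wrapped = make_wrapper(real_fn)
    file_path = wrapped.__code__.co_filename  # the wrapper's file
    for cell in wrapped.__closure__ or ():
        if not callable(cell.cell_contents):
            continue
        # in the diff, this path is tested against self.base_cache_dir
        file_path = cell.cell_contents.__code__.co_filename
        break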
@@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
             nonlocal file_path
             compiled_fn = inductor_compiled_graph.current_callable
             file_path = compiled_fn.__code__.co_filename  # noqa
-            if not file_path.startswith(self.cache_dir):
+            if not file_path.startswith(self.base_cache_dir):
                 # hooked in the align_inputs_from_check_idxs function
                 # in torch/_inductor/utils.py
                 for cell in compiled_fn.__closure__:
                     if not callable(cell.cell_contents):
                         continue
                     code = cell.cell_contents.__code__
-                    if code.co_filename.startswith(self.cache_dir):
+                    if code.co_filename.startswith(
+                            self.base_cache_dir):
                         # this is the real file path
                         # compiled from Inductor
                         file_path = code.co_filename