[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -28,11 +28,22 @@ class CompilerInterface:
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
e.g. by re-directing its own cache directory to a sub-directory.
|
||||
|
||||
prefix can be used in combination with cache_dir to figure out the base
|
||||
cache directory, e.g. there're multiple parts of model being compiled,
|
||||
but we want to share the same cache directory for all of them.
|
||||
|
||||
e.g.
|
||||
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
||||
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -166,7 +177,10 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
@@ -242,18 +256,23 @@ class InductorAdaptor(CompilerInterface):
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
def initialize_cache(self,
|
||||
cache_dir: str,
|
||||
disable_cache: bool = False,
|
||||
prefix: str = ""):
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a sub-directory
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(cache_dir, "inductor_cache")
|
||||
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(cache_dir, "triton_cache")
|
||||
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
|
||||
@@ -298,14 +317,14 @@ class InductorAdaptor(CompilerInterface):
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.cache_dir):
|
||||
if not file_path.startswith(self.base_cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
if cell.cell_contents.__code__.co_filename.startswith(
|
||||
self.cache_dir):
|
||||
self.base_cache_dir):
|
||||
# this is the real file path compiled from Inductor
|
||||
file_path = cell.cell_contents.__code__.co_filename
|
||||
break
|
||||
@@ -325,14 +344,15 @@ class InductorAdaptor(CompilerInterface):
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.cache_dir):
|
||||
if not file_path.startswith(self.base_cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.cache_dir):
|
||||
if code.co_filename.startswith(
|
||||
self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
|
||||
Reference in New Issue
Block a user