[v1] torch.compile integration explanation (#14437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-03-08 01:55:50 +08:00
committed by GitHub
parent 952a074980
commit c6359e8ca6
3 changed files with 154 additions and 1 deletions

View File

@@ -155,6 +155,7 @@ class InductorAdaptor(CompilerInterface):
triton_cache = os.path.join(cache_dir, "triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
self.cache_dir = cache_dir
def compile(
self,
@@ -200,7 +201,19 @@ class InductorAdaptor(CompilerInterface):
def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
compiled_fn = inductor_compiled_graph.current_callable
file_path = compiled_fn.__code__.co_filename # noqa
if not file_path.startswith(self.cache_dir):
# hooked in the align_inputs_from_check_idxs function
# in torch/_inductor/utils.py
for cell in compiled_fn.__closure__:
if not callable(cell.cell_contents):
continue
if cell.cell_contents.__code__.co_filename.startswith(
self.cache_dir):
# this is the real file path compiled from Inductor
file_path = cell.cell_contents.__code__.co_filename
break
return inductor_compiled_graph
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa