[ez] Delete more torch version checks <= 2.8 (#33288)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
@@ -375,14 +375,13 @@ class InductorAdaptor(CompilerInterface):
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
|
||||
from torch._inductor.codecache import compiled_fx_graph_hash
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args: Any, **kwargs: Any) -> Any:
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
@@ -395,44 +394,14 @@ class InductorAdaptor(CompilerInterface):
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
if cell.cell_contents.__code__.co_filename.startswith(
|
||||
self.base_cache_dir
|
||||
):
|
||||
# this is the real file path compiled from Inductor
|
||||
file_path = cell.cell_contents.__code__.co_filename
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
return inductor_compiled_graph
|
||||
|
||||
hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa
|
||||
elif torch.__version__ >= "2.6":
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if (
|
||||
not file_path.startswith(self.base_cache_dir)
|
||||
and compiled_fn.__closure__ is not None
|
||||
):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
@@ -453,10 +422,6 @@ class InductorAdaptor(CompilerInterface):
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
# hijack to get the compiled graph itself
|
||||
if original_load_name is not None:
|
||||
stack.enter_context(patch(original_load_name, hijack_load))
|
||||
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch(
|
||||
@@ -573,25 +538,16 @@ class InductorAdaptor(CompilerInterface):
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
if torch.__version__.startswith("2.5"):
|
||||
inductor_compiled_graph = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, False
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove "
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
elif torch.__version__ >= "2.6":
|
||||
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
||||
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
||||
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove "
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove "
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
|
||||
Reference in New Issue
Block a user