[torch.compile] Inductor code caching fix (#10273)

Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
This commit is contained in:
Luka Govedič
2024-11-21 00:44:57 -05:00
committed by GitHub
parent 9d827170a3
commit 8b0fe06c89
14 changed files with 602 additions and 286 deletions

View File

@@ -6,10 +6,11 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
logger = init_logger(__name__)
@@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs):
# Utilities for post-processing multi-output matches
def is_func(node: torch.fx.Node, target) -> bool:
    """
    Check whether an FX graph node is a call to the given function.

    :param node: the FX node to inspect
    :param target: the object compared (with ``==``) against ``node.target``
    :return: True only if ``node.op`` is "call_function" and
        ``node.target`` equals ``target``
    """
    return node.op == "call_function" and node.target == target
# Returns the first auto_functionalized node with the given op (if it exists)
@@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
return ret
class FusionPass(InductorPass):
class FusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
@@ -142,7 +141,7 @@ class FusionPass(InductorPass):
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: CompilationConfig):
def instance(cls, config: CompilationConfig.PassConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
@@ -154,7 +153,7 @@ class FusionPass(InductorPass):
cls._instance.config = config
return cls._instance
def __init__(self, config: CompilationConfig):
def __init__(self, config: CompilationConfig.PassConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
@@ -278,6 +277,7 @@ class FusionPass(InductorPass):
for node in match.nodes)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")
count = self.patterns.apply(graph)
@@ -289,3 +289,4 @@ class FusionPass(InductorPass):
logger.debug("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()
self.end_and_log()