[torch.compile] Inductor code caching fix (#10273)
Signed-off-by: luka <luka@neuralmagic.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
This commit is contained in:
@@ -6,10 +6,11 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
|
||||
fwd_only, register_replacement)
|
||||
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass, is_func
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs):
|
||||
|
||||
|
||||
# Utilities for post-processing multi-output matches
|
||||
def is_func(node: torch.fx.Node, target) -> bool:
|
||||
return node.op == "call_function" and node.target == target
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
@@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
|
||||
return ret
|
||||
|
||||
|
||||
class FusionPass(InductorPass):
|
||||
class FusionPass(VllmInductorPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
@@ -142,7 +141,7 @@ class FusionPass(InductorPass):
|
||||
_instance: 'Optional[FusionPass]' = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls, config: CompilationConfig):
|
||||
def instance(cls, config: CompilationConfig.PassConfig):
|
||||
"""
|
||||
Get the singleton instance of the FusionPass.
|
||||
If the instance exists, the config is updated but
|
||||
@@ -154,7 +153,7 @@ class FusionPass(InductorPass):
|
||||
cls._instance.config = config
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: CompilationConfig):
|
||||
def __init__(self, config: CompilationConfig.PassConfig):
|
||||
assert self.__class__._instance is None, \
|
||||
"FusionPass singleton instance already exists"
|
||||
super().__init__(config)
|
||||
@@ -278,6 +277,7 @@ class FusionPass(InductorPass):
|
||||
for node in match.nodes)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_fusion")
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
@@ -289,3 +289,4 @@ class FusionPass(InductorPass):
|
||||
logger.debug("Post-processed %s matches", len(self.matches))
|
||||
self.dump_graph(graph, "after_fusion")
|
||||
self.matches.clear()
|
||||
self.end_and_log()
|
||||
|
||||
Reference in New Issue
Block a user