[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-02-28 18:20:11 -05:00
committed by GitHub
parent 084bbac8cc
commit bd56c983d6
9 changed files with 239 additions and 160 deletions

View File

@@ -13,21 +13,26 @@ class TestBackend:
This class provides a simple Inductor backend that can be used for testing.
It takes a list of custom passes and runs them after Inductor's passes.
It also saves the graph before and after the custom passes for inspection.
Inductor config can be modified directly by editing the inductor_config
property. This can be helpful for adding passes like the
'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
"""
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
self.custom_passes = list(passes)
from torch._inductor import config
self.current_config = config.shallow_copy_dict()
self.current_config['force_disable_caches'] = True
self.current_config['post_grad_custom_post_pass'] = self.post_pass
self.inductor_config = config.shallow_copy_dict()
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.current_config)
config_patches=self.inductor_config)
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)