[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-02-28 18:20:11 -05:00
committed by GitHub
parent 084bbac8cc
commit bd56c983d6
9 changed files with 239 additions and 160 deletions

View File

@@ -11,7 +11,7 @@ from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .reshapes import RedundantReshapesPass
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):
The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (RedundantReshapesPass, FusionPass)
2. default passes (NoopEliminationPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_reshape:
self.passes += [RedundantReshapesPass(pass_config)]
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]