[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from vllm.logger import init_logger
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .fusion import FusionPass
|
||||
from .inductor_pass import InductorPass
|
||||
from .reshapes import RedundantReshapesPass
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):
|
||||
|
||||
The order of the post-grad post-passes is:
|
||||
1. passes (constructor parameter)
|
||||
2. default passes (RedundantReshapesPass, FusionPass)
|
||||
2. default passes (NoopEliminationPass, FusionPass)
|
||||
3. config["post_grad_custom_post_pass"] (if it exists)
|
||||
4. fix_functionalization
|
||||
This way, all passes operate on a functionalized graph.
|
||||
@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):
|
||||
|
||||
def configure(self, pass_config: CompilationConfig.PassConfig):
|
||||
self.pass_config = pass_config
|
||||
if pass_config.enable_reshape:
|
||||
self.passes += [RedundantReshapesPass(pass_config)]
|
||||
if pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(pass_config)]
|
||||
|
||||
if pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(pass_config)]
|
||||
|
||||
Reference in New Issue
Block a user