[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoOpEliminationPass (#10902)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
- from vllm.compilation.reshapes import RedundantReshapesPass
|
||||
+ from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig
|
||||
|
||||
from .backend import TestBackend
|
||||
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
||||
- enable_reshape=True)
|
||||
- reshape_pass = RedundantReshapesPass(config)
|
||||
+ enable_noop=True)
|
||||
+ noop_pass = NoOpEliminationPass(config)
|
||||
fusion_pass = FusionPass.instance(config)
|
||||
|
||||
- passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
|
||||
+ passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
|
||||
func_pass = FixFunctionalizationPass(config)
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
Reference in New Issue
Block a user