[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-02-28 18:20:11 -05:00
committed by GitHub
parent 084bbac8cc
commit bd56c983d6
9 changed files with 239 additions and 160 deletions

View File

@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graphs. Default is `.` (the current directory).
- enable_fusion: whether to enable the custom fusion pass.
- enable_reshape: whether to enable the custom reshape elimination pass.
TODO better pass enabling system.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
"""
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_reshape: bool = True
enable_noop: bool = True
def uuid(self):
"""
@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(
include={"enable_fusion", "enable_reshape"})
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).digest()
def model_post_init(self, __context: Any) -> None:
if not self.enable_reshape and self.enable_fusion:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but no-op elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
@@ -3411,7 +3410,7 @@ class VllmConfig:
self.compilation_config.use_inductor = True
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.pass_config.enable_noop = False
self.compilation_config.level = CompilationLevel.PIECEWISE
self._set_cudagraph_sizes()