[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
|
||||
Each pass defines its own stages (before, after, maybe in-between).
|
||||
- dump_graph_dir: directory to dump the graphs. Default is "." (the current directory).
|
||||
- enable_fusion: whether to enable the custom fusion pass.
|
||||
- enable_reshape: whether to enable the custom reshape elimination pass.
|
||||
TODO better pass enabling system.
|
||||
- enable_noop: whether to enable the custom no-op elimination pass.
|
||||
TODO(luka) better pass enabling system.
|
||||
"""
|
||||
dump_graph_stages: List[str] = Field(default_factory=list)
|
||||
dump_graph_dir: Path = Field(default=Path("."))
|
||||
enable_fusion: bool = True
|
||||
enable_reshape: bool = True
|
||||
enable_noop: bool = True
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
|
||||
Do not include dump_graph_* in the hash - they don't affect
|
||||
compilation.
|
||||
"""
|
||||
dict_ = self.model_dump(
|
||||
include={"enable_fusion", "enable_reshape"})
|
||||
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).digest()
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if not self.enable_reshape and self.enable_fusion:
|
||||
if not self.enable_noop and self.enable_fusion:
|
||||
logger.warning_once(
|
||||
"Fusion enabled but reshape elimination disabled. "
|
||||
"RMSNorm + quant (fp8) fusion might not work")
|
||||
@@ -3411,7 +3410,7 @@ class VllmConfig:
|
||||
self.compilation_config.use_inductor = True
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
self.compilation_config.pass_config.enable_fusion = False
|
||||
self.compilation_config.pass_config.enable_reshape = False
|
||||
self.compilation_config.pass_config.enable_noop = False
|
||||
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
|
||||
Reference in New Issue
Block a user