[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-06-12 11:31:04 -04:00
committed by GitHub
parent 96846bb360
commit f98548b9da
33 changed files with 622 additions and 79 deletions

View File

@@ -3804,9 +3804,10 @@ class PassConfig:
its own stages (before, after, maybe in-between)."""
dump_graph_dir: Path = Path(".")
"""Directory to dump the graphs."""
enable_fusion: bool = True
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion: bool = False
"""Whether to enable the custom attention+quant fusion pass."""
enable_noop: bool = True
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
@@ -3814,6 +3815,8 @@ class PassConfig:
enable_async_tp: bool = False
"""Whether to enable async TP."""
# TODO(luka) better pass enabling system.
def uuid(self):
    """
    Produces a hash unique to the pass configuration.

    Do not include dump_graph_* in the hash - they don't affect
    compilation.
    """
    # Hash every config field EXCEPT the dump_graph_* knobs, so that any
    # newly added pass-enable flag automatically participates in the hash
    # (the previous allow-list approach silently ignored new fields).
    exclude = {"dump_graph_stages", "dump_graph_dir"}
    dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
    return InductorPass.hash_dict(dict_)
def __post_init__(self) -> None:
    """Warn about pass combinations that are unlikely to work as intended."""
    # The no-op (reshape) elimination pass appears to be a prerequisite for
    # the fusion passes to match their patterns; warn when any fusion pass
    # is enabled without it rather than failing silently.
    if not self.enable_noop:
        if self.enable_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "RMSNorm/SiluMul + quant (fp8) fusion might not work")
        if self.enable_attn_fusion:
            logger.warning_once(
                "Fusion enabled but reshape elimination disabled. "
                "Attention + quant (fp8) fusion might not work")
@config