[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
@@ -3804,9 +3804,10 @@ class PassConfig:
     its own stages (before, after, maybe in-between)."""
     dump_graph_dir: Path = Path(".")
     """Directory to dump the graphs."""
-    # TODO(luka) better pass enabling system.
     enable_fusion: bool = True
-    """Whether to enable the custom fusion pass."""
+    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
+    enable_attn_fusion: bool = False
+    """Whether to enable the custom attention+quant fusion pass."""
     enable_noop: bool = True
     """Whether to enable the custom no-op elimination pass."""
     enable_sequence_parallelism: bool = False
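The hunk above adds the new `enable_attn_fusion` knob, off by default. As a usage sketch, assuming vLLM's standard wiring where `PassConfig` sits on `CompilationConfig.pass_config` (the model id and compilation level are illustrative, not part of this commit):

# Usage sketch (hypothetical model id; CompilationConfig/PassConfig wiring
# assumed from vLLM, not shown in this commit).
from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

llm = LLM(
    model="some-org/some-fp8-model",  # illustrative placeholder
    compilation_config=CompilationConfig(
        level=3,  # piecewise torch.compile, where custom passes run (assumption)
        pass_config=PassConfig(
            enable_attn_fusion=True,  # new flag from this commit
            enable_noop=True,  # keep no-op elimination on, per the warning below
        ),
    ),
)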
@@ -3814,6 +3815,8 @@ class PassConfig:
     enable_async_tp: bool = False
     """Whether to enable async TP."""
 
+    # TODO(luka) better pass enabling system.
+
     def uuid(self):
         """
         Produces a hash unique to the pass configuration.
@@ -3821,18 +3824,20 @@ class PassConfig:
         Do not include dump_graph_* in the hash - they don't affect
         compilation.
         """
-        include = {
-            "enable_fusion", "enable_noop", "enable_sequence_parallelism",
-            "enable_async_tp"
-        }
-        dict_ = {k: v for k, v in asdict(self).items() if k in include}
+        exclude = {"dump_graph_stages", "dump_graph_dir"}
+        dict_ = {k: v for k, v in asdict(self).items() if k not in exclude}
         return InductorPass.hash_dict(dict_)
 
     def __post_init__(self) -> None:
-        if not self.enable_noop and self.enable_fusion:
-            logger.warning_once(
-                "Fusion enabled but reshape elimination disabled. "
-                "RMSNorm + quant (fp8) fusion might not work")
+        if not self.enable_noop:
+            if self.enable_fusion:
+                logger.warning_once(
+                    "Fusion enabled but reshape elimination disabled. "
+                    "RMSNorm/SiluMul + quant (fp8) fusion might not work")
+            if self.enable_attn_fusion:
+                logger.warning_once(
+                    "Fusion enabled but reshape elimination disabled. "
+                    "Attention + quant (fp8) fusion might not work")
 
 
 @config
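Note the `uuid()` rewrite above: hashing everything except the `dump_graph_*` fields, instead of an explicit include set, means newly added flags such as `enable_attn_fusion` fold into the hash automatically, so toggling them invalidates torch.compile's cached artifacts. A self-contained sketch of that behavior; the sha256-over-JSON digest stands in for `InductorPass.hash_dict`, whose exact implementation is an assumption, not shown in this commit:

import hashlib
import json
from dataclasses import asdict, dataclass

@dataclass
class PassConfigSketch:  # hypothetical stand-in for PassConfig
    enable_fusion: bool = True
    enable_attn_fusion: bool = False
    enable_noop: bool = True
    dump_graph_stages: tuple = ()  # excluded from the hash
    dump_graph_dir: str = "."      # excluded from the hash

def uuid(cfg: PassConfigSketch) -> str:
    exclude = {"dump_graph_stages", "dump_graph_dir"}
    dict_ = {k: v for k, v in asdict(cfg).items() if k not in exclude}
    # Assumed equivalent of InductorPass.hash_dict: stable JSON -> sha256.
    return hashlib.sha256(json.dumps(dict_, sort_keys=True).encode()).hexdigest()

# Flipping the new flag changes the hash, forcing recompilation:
assert uuid(PassConfigSketch()) != uuid(PassConfigSketch(enable_attn_fusion=True))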