[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Author: Luka Govedič
Date: 2025-06-12 11:31:04 -04:00
Committed by: GitHub
commit f98548b9da, parent 96846bb360
33 changed files with 622 additions and 79 deletions
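In short, a pass of this kind pattern-matches attention ops whose output feeds a static quantization op and rewrites the graph so a single fused call produces the quantized output directly, skipping the intermediate high-precision tensor. Below is a minimal, self-contained sketch of that style of FX graph rewrite; quant_fp8, fused_attn_quant, pattern, and replacement are illustrative stand-ins defined here, not vLLM's actual AttnFusionPass, which runs inside the Inductor post-grad pipeline.

import torch
import torch.fx as fx
import torch.nn.functional as F

def quant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # stand-in static FP8 quantization: scale and cast to float8_e4m3fn
    return (x / scale).to(torch.float8_e4m3fn)

def fused_attn_quant(q, k, v, scale):
    # stand-in "fused" kernel; a real backend would emit one attention
    # kernel that writes FP8 output directly
    return quant_fp8(F.scaled_dot_product_attention(q, k, v), scale)

# keep both helpers as opaque leaf calls in traced graphs, so the
# pattern matcher sees single nodes instead of their inlined bodies
torch.fx.wrap("quant_fp8")
torch.fx.wrap("fused_attn_quant")

def pattern(q, k, v, scale):
    # attention followed by a separate quantization kernel
    return quant_fp8(F.scaled_dot_product_attention(q, k, v), scale)

def replacement(q, k, v, scale):
    # the fused equivalent
    return fused_attn_quant(q, k, v, scale)

class Model(torch.nn.Module):
    def forward(self, q, k, v, scale):
        return quant_fp8(F.scaled_dot_product_attention(q, k, v), scale)

gm = fx.symbolic_trace(Model())
fx.replace_pattern(gm, pattern, replacement)  # rewrite the graph in place
print(gm.code)  # the graph now contains a single fused_attn_quant call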

vllm/compilation/pass_manager.py

@@ -10,6 +10,7 @@ from .activation_quant_fusion import ActivationQuantFusionPass
 from .collective_fusion import AsyncTPPass
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
+from .fusion_attn import AttnFusionPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
 from .noop_elimination import NoOpEliminationPass
 from .sequence_parallelism import SequenceParallelismPass
@@ -59,6 +60,9 @@ class PostGradPassManager(CustomGraphPass):
         if self.pass_config.enable_async_tp:
             self.passes += [AsyncTPPass(config)]
 
+        if self.pass_config.enable_attn_fusion:
+            self.passes += [AttnFusionPass(config)]
+
         self.fix_functionalization = FixFunctionalizationPass(config)
 
     def add(self, pass_: InductorPass):
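For completeness, a hypothetical sketch of turning the pass on from user code. The enable_attn_fusion flag is taken from the diff above; the LLM/CompilationConfig/PassConfig surface and import paths are assumptions that may differ across vLLM versions.

from vllm import LLM
from vllm.config import CompilationConfig, PassConfig  # paths assumed

llm = LLM(
    model="neuralmagic/Meta-Llama-3-8B-Instruct-FP8",  # illustrative FP8 checkpoint
    compilation_config=CompilationConfig(
        level=3,  # compile the model so post-grad passes actually run
        pass_config=PassConfig(enable_attn_fusion=True),
    ),
)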