Integrate SM100 FlashInfer fused allreduce RMSNorm (#20691)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
Ilya Markov
2025-07-12 03:58:15 +02:00
committed by GitHub
parent 7b828e30d5
commit fc0f41d10a
4 changed files with 514 additions and 6 deletions

View File

@@ -7,7 +7,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from .activation_quant_fusion import ActivationQuantFusionPass
-from .collective_fusion import AsyncTPPass
+from .collective_fusion import AllReduceFusionPass, AsyncTPPass
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .fusion_attn import AttnFusionPass
@@ -62,7 +62,11 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
+if self.pass_config.enable_fi_allreduce_fusion:
+self.passes += [
+AllReduceFusionPass(
+config, self.pass_config.fi_allreduce_fusion_max_token_num)
+]
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):