Integration SM100 FlashInfer fused allreduce RMSNorm (#20691)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .collective_fusion import AsyncTPPass
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .fusion import FusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
@@ -62,7 +62,11 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [
|
||||
AllReduceFusionPass(
|
||||
config, self.pass_config.fi_allreduce_fusion_max_token_num)
|
||||
]
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
|
||||
Reference in New Issue
Block a user