[Feature]Add async tensor parallelism using compilation pass (#17882)
Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .collective_fusion import AsyncTPPass
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .fusion import FusionPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
@@ -54,6 +55,8 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user