[Feature] support sequence parallelism using compilation pass (#16155)
Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -4,13 +4,15 @@ from typing import List
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .fusion import FusionPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.passes: List[InductorPass] = []
|
||||
self.passes: List[VllmInductorPass] = []
|
||||
|
||||
def __call__(self, graph: fx.Graph):
|
||||
shape = get_pass_context().runtime_shape
|
||||
for pass_ in self.passes:
|
||||
pass_(graph)
|
||||
if pass_.is_applicable_for_shape(shape):
|
||||
pass_(graph)
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
|
||||
def configure(self, pass_config: CompilationConfig.PassConfig):
|
||||
self.pass_config = pass_config
|
||||
if pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(pass_config)]
|
||||
def configure(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(pass_config)]
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(config)]
|
||||
|
||||
self.fix_functionalization = FixFunctionalizationPass(pass_config)
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
assert isinstance(pass_, InductorPass)
|
||||
|
||||
Reference in New Issue
Block a user