[Feature] support sequence parallelism using compilation pass (#16155)

Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
cascade
2025-04-27 06:29:35 -07:00
committed by GitHub
parent ed7a29d9f8
commit 690fe019f0
21 changed files with 1072 additions and 44 deletions

View File

@@ -4,13 +4,15 @@ from typing import List
from torch import fx as fx
from vllm.config import CompilationConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import CustomGraphPass, InductorPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
from .noop_elimination import NoOpEliminationPass
from .sequence_parallelism import SequenceParallelismPass
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
@@ -31,24 +33,29 @@ class PostGradPassManager(CustomGraphPass):
"""
def __init__(self):
self.passes: List[InductorPass] = []
self.passes: List[VllmInductorPass] = []
def __call__(self, graph: fx.Graph):
shape = get_pass_context().runtime_shape
for pass_ in self.passes:
pass_(graph)
if pass_.is_applicable_for_shape(shape):
pass_(graph)
# always run fix_functionalization last
self.fix_functionalization(graph)
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.fix_functionalization = FixFunctionalizationPass(pass_config)
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)