[Feature] support sequence parallelism using compilation pass (#16155)
Signed-off-by: cascade812 <cascade812@outlook.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
# dummy custom pass that doesn't inherit
|
||||
@@ -16,7 +16,7 @@ def simple_callable(graph: torch.fx.Graph):
|
||||
|
||||
# Should fail to add directly to the pass manager
|
||||
def test_bad_callable():
|
||||
config = CompilationConfig().pass_config
|
||||
config = VllmConfig()
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
@@ -43,7 +43,7 @@ class ProperPass(InductorPass):
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
config = CompilationConfig().pass_config
|
||||
config = VllmConfig()
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
@@ -64,7 +64,8 @@ def test_pass_manager_uuid(callable):
|
||||
|
||||
# UUID should be different due to config change
|
||||
config2 = copy.deepcopy(config)
|
||||
config2.enable_fusion = not config2.enable_fusion
|
||||
config2.compilation_config.pass_config.enable_fusion = not \
|
||||
config2.compilation_config.pass_config.enable_fusion
|
||||
pass_manager3 = PostGradPassManager()
|
||||
pass_manager3.configure(config2)
|
||||
pass_manager3.add(callable)
|
||||
|
||||
Reference in New Issue
Block a user