[Feature] support sequence parallelism using compilation pass (#16155)
Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -531,7 +531,7 @@ class FusionPass(VllmInductorPass):
|
||||
_instance: 'Optional[FusionPass]' = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls, config: CompilationConfig.PassConfig):
|
||||
def instance(cls, config: VllmConfig):
|
||||
"""
|
||||
Get the singleton instance of the FusionPass.
|
||||
If the instance exists, the config is updated but
|
||||
@@ -540,10 +540,10 @@ class FusionPass(VllmInductorPass):
|
||||
if cls._instance is None:
|
||||
cls._instance = FusionPass(config)
|
||||
else:
|
||||
cls._instance.config = config
|
||||
cls._instance.pass_config = config.compilation_config.pass_config
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: CompilationConfig.PassConfig):
|
||||
def __init__(self, config: VllmConfig):
|
||||
assert self.__class__._instance is None, \
|
||||
"FusionPass singleton instance already exists"
|
||||
super().__init__(config)
|
||||
|
||||
Reference in New Issue
Block a user