[Feature] Add async tensor parallelism using compilation pass (#17882)
Signed-off-by: cascade812 <cascade812@outlook.com>
This commit is contained in:
@@ -3652,6 +3652,8 @@ class PassConfig:
|
||||
"""Whether to enable the custom no-op elimination pass."""
|
||||
enable_sequence_parallelism: bool = False
|
||||
"""Whether to enable sequence parallelism."""
|
||||
enable_async_tp: bool = False
|
||||
"""Whether to enable async TP."""
|
||||
|
||||
def uuid(self):
|
||||
"""
|
||||
@@ -3661,7 +3663,8 @@ class PassConfig:
|
||||
compilation.
|
||||
"""
|
||||
include = {
|
||||
"enable_fusion", "enable_noop", "enable_sequence_parallelism"
|
||||
"enable_fusion", "enable_noop", "enable_sequence_parallelism",
|
||||
"enable_async_tp"
|
||||
}
|
||||
dict_ = {k: v for k, v in asdict(self).items() if k in include}
|
||||
return InductorPass.hash_dict(dict_)
|
||||
@@ -4274,6 +4277,12 @@ class VllmConfig:
|
||||
|
||||
if self.compilation_config is None:
|
||||
self.compilation_config = CompilationConfig()
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.enable_async_tp:
|
||||
self.compilation_config.pass_config.enable_sequence_parallelism = \
|
||||
True
|
||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
if envs.VLLM_USE_V1 and self.model_config is not None and \
|
||||
|
||||
Reference in New Issue
Block a user