[Feature] support sequence parallelism using compilation pass (#16155)

Signed-off-by: cascade812 <cascade812@outlook.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
cascade
2025-04-27 06:29:35 -07:00
committed by GitHub
parent ed7a29d9f8
commit 690fe019f0
21 changed files with 1072 additions and 44 deletions

View File

@@ -3405,11 +3405,13 @@ class CompilationConfig(BaseModel):
- enable_fusion: whether to enable the custom fusion pass.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
- enable_sequence_parallelism: whether to enable sequence parallelism.
"""
dump_graph_stages: list[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_noop: bool = True
enable_sequence_parallelism: bool = False
def uuid(self):
"""
@@ -3418,7 +3420,8 @@ class CompilationConfig(BaseModel):
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
"enable_sequence_parallelism"})
return InductorPass.hash_dict(dict_)
def model_post_init(self, __context: Any) -> None:
@@ -3840,6 +3843,8 @@ class VllmConfig:
if self.compilation_config is None:
self.compilation_config = CompilationConfig()
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if envs.VLLM_USE_V1 and self.model_config is not None and \
not self.model_config.enforce_eager:
# NOTE(woosuk): Currently, we use inductor because the piecewise
@@ -3847,7 +3852,8 @@ class VllmConfig:
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
# FIXME(rob): Add function to set all of these.
self.compilation_config.custom_ops = ["none"]
if not self.compilation_config.custom_ops:
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
self.compilation_config.use_inductor = True
self.compilation_config.cudagraph_num_of_warmups = 1
@@ -3856,6 +3862,18 @@ class VllmConfig:
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
if self.parallel_config is not None and \
self.parallel_config.tensor_parallel_size > 1 and \
self.parallel_config.pipeline_parallel_size > 1 and \
self.compilation_config is not None and \
self.compilation_config.pass_config is not None and \
self.compilation_config.pass_config.enable_sequence_parallelism:
logger.warning_once(
"Sequence parallelism is not supported with pipeline "
"parallelism. Disabling sequence parallelism.")
self.compilation_config.pass_config.\
enable_sequence_parallelism = False
self._set_cudagraph_sizes()
if self.cache_config is not None and \
@@ -3895,6 +3913,26 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
def update_sizes_for_sequence_parallelism(self,
                                          possible_sizes: list) -> list:
    """Filter candidate cudagraph batch sizes to multiples of tp_size.

    Sequence parallelism shards activations across tensor-parallel
    ranks, so every captured batch size must be divisible by
    ``tensor_parallel_size``; any other size is dropped with a warning.

    Args:
        possible_sizes: candidate batch sizes to capture.

    Returns:
        The subset of ``possible_sizes`` divisible by tp_size, in the
        original order.
    """
    # Hoist the attribute chain once instead of re-reading it per size.
    tp_size = self.parallel_config.tensor_parallel_size

    # Single pass: partition sizes into kept vs. removed.
    kept_sizes: list = []
    removed_sizes: list = []
    for size in possible_sizes:
        (kept_sizes if size % tp_size == 0 else removed_sizes).append(size)

    if removed_sizes:
        # Lazy %-style args so formatting only happens when emitted.
        logger.warning(
            "Batch sizes %s are removed because they are not "
            "multiple of tp_size %d when "
            "sequence parallelism is enabled", removed_sizes,
            tp_size)

    return kept_sizes
def _set_cudagraph_sizes(self):
"""
cudagraph batchsize padding logic:
@@ -3932,6 +3970,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
possible_sizes = self.update_sizes_for_sequence_parallelism(
possible_sizes)
# find the minimum size that is larger than max_num_seqs,
# which then becomes the max_batchsize_to_capture
larger_sizes = [
@@ -3955,6 +3998,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]
if self.parallel_config.tensor_parallel_size > 1 and \
self.compilation_config.pass_config.enable_sequence_parallelism:
batch_size_capture_list = \
self.update_sizes_for_sequence_parallelism(batch_size_capture_list)
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list