[torch.compile] Sequence Parallelism threshold compile ranges (#28672)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -118,7 +118,9 @@ class PassConfig:
     eliminate_noops: bool = Field(default=True)
     """Eliminate no-op ops."""
     enable_sp: bool = Field(default=None)
-    """Enable sequence parallelism."""
+    """Enable sequence parallelism. Requires TP>1. Automatically disabled
+    if the model's hidden_size is too small for SP to be beneficial
+    (threshold is device-capability dependent)."""
     fuse_gemm_comms: bool = Field(default=None)
     """Enable async TP."""
     fuse_allreduce_rms: bool = Field(default=None)
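For context, a minimal way to exercise the new behavior end to end; a hedged sketch that assumes the vllm.LLM entry point (which accepts compilation_config as a dict) and uses a placeholder model name:

    from vllm import LLM

    # Sketch: enable the sequence-parallelism pass explicitly. SP requires
    # tensor_parallel_size > 1, and with this change it is also auto-disabled
    # when the model's hidden_size is below the device-dependent threshold.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        tensor_parallel_size=2,
        compilation_config={"pass_config": {"enable_sp": True}},
    )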
@@ -155,6 +157,11 @@ class PassConfig:
             8: 1,  # 1MB
         },
     }, where key is the device capability"""
+    sp_min_token_num: int | None = None
+    """The minimum number of tokens above which vllm should use
+    sequence parallelism. Specified as an integer token count.
+    Unspecified will fall back to default values which are compute
+    capability and world size dependent."""
 
     # TODO(luka) better pass enabling system.
 
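The new sp_min_token_num field is the escape hatch from the heuristic: a user-supplied value is honored regardless of hidden_size. A hedged usage sketch (placeholder model and threshold, same assumptions as above):

    # Sketch: force SP on a small model by pinning the token threshold
    # manually instead of relying on the capability/world-size heuristic.
    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
        tensor_parallel_size=2,
        compilation_config={
            "pass_config": {"enable_sp": True, "sp_min_token_num": 512}
        },
    )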
@@ -853,8 +853,33 @@ class VllmConfig:
             logger.warning("Sequence Parallelism requires TP>1, disabling")
             self.compilation_config.pass_config.enable_sp = False
             self.compilation_config.pass_config.fuse_gemm_comms = False
-        elif "-rms_norm" in self.compilation_config.custom_ops:
-            logger.warning(
-                "RMS norm force disabled, sequence parallelism might break"
-            )
-
+        else:
+            # Compute SP threshold early; disable if None (model too
+            # small) before +rms_norm gets forced into custom_ops.
+            pass_config = self.compilation_config.pass_config
+            if pass_config.sp_min_token_num is None:
+                from vllm.compilation.passes.fusion.sequence_parallelism import (
+                    get_sequence_parallelism_threshold,
+                )
+
+                tp_size = self.parallel_config.tensor_parallel_size
+                hidden_size = self.model_config.get_hidden_size()
+                element_size = self.model_config.dtype.itemsize
+                pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
+                    hidden_size, tp_size, element_size
+                )
+
+            if pass_config.sp_min_token_num is None:
+                logger.warning(
+                    "Model hidden_size too small for the SP "
+                    "threshold heuristic, disabling. To force SP, "
+                    "set pass_config.sp_min_token_num manually."
+                )
+                self.compilation_config.pass_config.enable_sp = False
+                self.compilation_config.pass_config.fuse_gemm_comms = False
+
+        if self.compilation_config.pass_config.enable_sp:
+            if "-rms_norm" in self.compilation_config.custom_ops:
+                logger.warning(
+                    "RMS norm force disabled, sequence parallelism might break"
+                )
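The body of get_sequence_parallelism_threshold is not shown in this diff. Based only on its call signature and the docstrings above (a per-device-capability table of megabyte thresholds, keyed further by world size), a plausible sketch of the conversion to a token count could look like the following; the table values, the capability parameter, and the degenerate-case guard are all hypothetical:

    # Hypothetical sketch; the real table and formula live in
    # vllm/compilation/passes/fusion/sequence_parallelism.py and may differ.
    _MB_THRESHOLDS: dict[int, dict[int, int]] = {
        # {device capability major: {tp world size: threshold in MB}}
        9: {8: 1},  # placeholder values, e.g. 1MB at capability 9.x, TP=8
    }
    _MAX_USEFUL_TOKENS = 32768  # hypothetical cap meaning "SP never pays off"

    def get_sequence_parallelism_threshold(
        hidden_size: int, tp_size: int, element_size: int, capability: int = 9
    ) -> int | None:
        """Min token count above which SP helps, or None to disable SP."""
        mb = _MB_THRESHOLDS.get(capability, {}).get(tp_size)
        if mb is None:
            return None
        # Each token's hidden state occupies hidden_size * element_size bytes;
        # SP pays off once the communicated activations exceed the threshold.
        min_tokens = (mb * 1024 * 1024) // (hidden_size * element_size)
        # A tiny hidden_size pushes the break-even point so high that SP never
        # helps; return None so the caller disables SP (the "hidden_size too
        # small" warning in the hunk above).
        return min_tokens if min_tokens <= _MAX_USEFUL_TOKENS else None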
@@ -1456,6 +1481,36 @@ class VllmConfig:
                 "allreduce-rms fusion will be enabled for all num_tokens."
             )
 
+        # Add the compile ranges for sequence parallelism
+        if compilation_config.pass_config.enable_sp:
+            pass_config = compilation_config.pass_config
+
+            # Calculate min_token_num if not explicitly provided.
+            # User override works regardless of hidden_size.
+            if pass_config.sp_min_token_num is None:
+                from vllm.compilation.passes.fusion.sequence_parallelism import (
+                    get_sequence_parallelism_threshold,
+                )
+
+                tp_size = self.parallel_config.tensor_parallel_size
+                hidden_size = self.model_config.get_hidden_size()
+                element_size = self.model_config.dtype.itemsize
+                pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
+                    hidden_size, tp_size, element_size
+                )
+
+            min_token_num = pass_config.sp_min_token_num
+            max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+            if min_token_num is not None and (
+                max_num_batched_tokens is not None
+                and min_token_num < max_num_batched_tokens
+                and min_token_num > 1
+            ):
+                # Add split point at min_token_num - 1 to ensure SP applies
+                # starting from min_token_num. This creates ranges:
+                # [1, min-1] (no SP), [min, max] (SP applies)
+                computed_compile_ranges_split_points.append(min_token_num - 1)
+
         if compilation_config.pass_config.fuse_rope_kvcache:
             max_token_num = (
                 compilation_config.pass_config.rope_kvcache_fusion_max_token_num
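To make the split-point arithmetic concrete: appending min_token_num - 1 as a split point partitions [1, max_num_batched_tokens] into a no-SP range and an SP range. A standalone sketch (the helper name and list-of-tuples representation are hypothetical, not vLLM's actual compile-range machinery):

    def compile_ranges(
        split_points: list[int], max_tokens: int
    ) -> list[tuple[int, int]]:
        """Turn sorted split points into inclusive [lo, hi] compile ranges."""
        ranges, lo = [], 1
        for p in sorted(split_points):
            ranges.append((lo, p))
            lo = p + 1
        ranges.append((lo, max_tokens))
        return ranges

    # With sp_min_token_num = 512 and max_num_batched_tokens = 8192, the split
    # point 511 yields [1, 511] (SP off) and [512, 8192] (SP applies).
    print(compile_ranges([512 - 1], 8192))  # [(1, 511), (512, 8192)]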