[torch.compile] Sequence Parallelism threshold compile ranges (#28672)

Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Jason Li committed 2026-02-25 21:00:12 -08:00 (committed by GitHub)
parent 4171ff6dd9 · commit 9d37941017
8 changed files with 524 additions and 32 deletions


@@ -853,8 +853,33 @@ class VllmConfig:
logger.warning("Sequence Parallelism requires TP>1, disabling")
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
else:
# Compute SP threshold early; disable if None (model too
# small) before +rms_norm gets forced into custom_ops.
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold,
)
elif "-rms_norm" in self.compilation_config.custom_ops:
tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size
)
if pass_config.sp_min_token_num is None:
logger.warning(
"Model hidden_size too small for the SP "
"threshold heuristic, disabling. To force SP, "
"set pass_config.sp_min_token_num manually."
)
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
if self.compilation_config.pass_config.enable_sp:
if "-rms_norm" in self.compilation_config.custom_ops:
logger.warning(
"RMS norm force disabled, sequence parallelism might break"
)
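
The body of `get_sequence_parallelism_threshold` is not shown in this hunk. As a rough illustration of the shape such a heuristic could take — the function name and signature match the diff, but the formula and the break-even constant below are assumptions, not the PR's actual logic — the threshold can be derived from the per-token bytes each TP rank moves:

```python
# Illustrative sketch only: the real get_sequence_parallelism_threshold in
# vllm.compilation.passes.fusion.sequence_parallelism may use a different
# formula; the 1 MiB break-even constant here is made up.
def get_sequence_parallelism_threshold(
    hidden_size: int, tp_size: int, element_size: int
) -> int | None:
    # Bytes of hidden state each TP rank handles per token once the
    # sequence is sharded tp_size ways.
    bytes_per_token_per_rank = (hidden_size * element_size) // tp_size
    if bytes_per_token_per_rank == 0:
        # hidden_size too small relative to tp_size; signal the caller
        # to disable SP (matching the None handling in the diff above).
        return None
    # Assume SP pays off once each rank moves at least ~1 MiB per step.
    break_even_bytes = 1 << 20
    # Ceiling division: smallest token count reaching the break-even point.
    return -(-break_even_bytes // bytes_per_token_per_rank)
```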
@@ -1456,6 +1481,36 @@ class VllmConfig:
"allreduce-rms fusion will be enabled for all num_tokens."
)
# Add the compile ranges for sequence parallelism
if compilation_config.pass_config.enable_sp:
pass_config = compilation_config.pass_config
# Calculate min_token_num if not explicitly provided
# User override works regardless of hidden_size
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold,
)
tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size
)
min_token_num = pass_config.sp_min_token_num
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
if min_token_num is not None and (
max_num_batched_tokens is not None
and min_token_num < max_num_batched_tokens
and min_token_num > 1
):
# Add split point at min_token_num - 1 to ensure SP applies
# starting from min_token_num
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
computed_compile_ranges_split_points.append(min_token_num - 1)
if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = (
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
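
To make the split-point arithmetic in the second hunk concrete, here is a toy walk-through; the numbers are made up, and `split_points` stands in for `computed_compile_ranges_split_points`:

```python
# Toy values: in vLLM these come from the threshold heuristic and from
# scheduler_config.max_num_batched_tokens respectively.
min_token_num = 256
max_num_batched_tokens = 8192

split_points: list[int] = []
if 1 < min_token_num < max_num_batched_tokens:
    # Splitting at min_token_num - 1 makes the SP range start exactly at
    # min_token_num: [1, 255] compiles without SP, [256, 8192] with SP.
    split_points.append(min_token_num - 1)

assert split_points == [255]
```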