[torch.compile] Sequence Parallelism threshold compile ranges (#28672)

Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Commit 9d37941017 (parent 4171ff6dd9)
Authored by Jason Li, 2026-02-25 21:00:12 -08:00; committed by GitHub
8 changed files with 524 additions and 32 deletions


@@ -118,7 +118,9 @@ class PassConfig:
     eliminate_noops: bool = Field(default=True)
     """Eliminate no-op ops."""
     enable_sp: bool = Field(default=None)
-    """Enable sequence parallelism."""
+    """Enable sequence parallelism. Requires TP>1. Automatically disabled
+    if the model's hidden_size is too small for SP to be beneficial
+    (threshold is device-capability dependent)."""
     fuse_gemm_comms: bool = Field(default=None)
     """Enable async TP."""
     fuse_allreduce_rms: bool = Field(default=None)
@@ -155,6 +157,11 @@ class PassConfig:
             8: 1,  # 1MB
         },
     }, where key is the device capability"""
+    sp_min_token_num: int | None = None
+    """The minimum number of tokens above which vLLM should use
+    sequence parallelism. Specified as an integer token count. If
+    unspecified, this falls back to a default that is compute-capability
+    and world-size dependent."""
     # TODO(luka) better pass enabling system.
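
Taken together, the two new PassConfig fields form the user-facing knob for
this change. A minimal usage sketch, assuming PassConfig and CompilationConfig
are constructed directly (the 2048-token threshold is an arbitrary example,
not a recommended value):

    from vllm.config import CompilationConfig, PassConfig

    # Sketch: enable sequence parallelism and override the heuristic so the
    # SP pass applies to any batch of at least 2048 tokens.
    compilation_config = CompilationConfig(
        pass_config=PassConfig(
            enable_sp=True,
            sp_min_token_num=2048,
        )
    )

With sp_min_token_num left unset, the capability- and world-size-dependent
default from get_sequence_parallelism_threshold is used instead.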


@@ -853,8 +853,33 @@ class VllmConfig:
                 logger.warning("Sequence Parallelism requires TP>1, disabling")
                 self.compilation_config.pass_config.enable_sp = False
                 self.compilation_config.pass_config.fuse_gemm_comms = False
-            elif "-rms_norm" in self.compilation_config.custom_ops:
+            else:
+                # Compute SP threshold early; disable if None (model too
+                # small) before +rms_norm gets forced into custom_ops.
+                pass_config = self.compilation_config.pass_config
+                if pass_config.sp_min_token_num is None:
+                    from vllm.compilation.passes.fusion.sequence_parallelism import (
+                        get_sequence_parallelism_threshold,
+                    )
+
+                    tp_size = self.parallel_config.tensor_parallel_size
+                    hidden_size = self.model_config.get_hidden_size()
+                    element_size = self.model_config.dtype.itemsize
+                    pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
+                        hidden_size, tp_size, element_size
+                    )
+                if pass_config.sp_min_token_num is None:
+                    logger.warning(
+                        "Model hidden_size too small for the SP "
+                        "threshold heuristic, disabling. To force SP, "
+                        "set pass_config.sp_min_token_num manually."
+                    )
+                    self.compilation_config.pass_config.enable_sp = False
+                    self.compilation_config.pass_config.fuse_gemm_comms = False
+        if self.compilation_config.pass_config.enable_sp:
+            if "-rms_norm" in self.compilation_config.custom_ops:
                 logger.warning(
                     "RMS norm force disabled, sequence parallelism might break"
                 )
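
For intuition about the helper imported here: the communication that SP
restructures scales with hidden_size * element_size bytes per token, so the
break-even token count shrinks as the hidden size grows. The sketch below is
an illustration of that shape only, not vLLM's actual heuristic; the byte
threshold, the capability keys, and the None cutoff are all assumptions:

    # Illustrative sketch -- NOT the actual implementation in
    # vllm/compilation/passes/fusion/sequence_parallelism.py.
    _MIN_SAVED_BYTES = {8: 1 << 20, 9: 1 << 20}  # assumed, keyed by capability major

    def get_sequence_parallelism_threshold(
        hidden_size: int, tp_size: int, element_size: int, capability: int = 9
    ) -> int | None:
        min_bytes = _MIN_SAVED_BYTES.get(capability, 1 << 20)
        bytes_per_token = hidden_size * element_size
        min_tokens = -(-min_bytes // bytes_per_token)  # ceil division
        # SP shards the sequence across TP ranks, so round the threshold up
        # to a multiple of tp_size.
        min_tokens = ((min_tokens + tp_size - 1) // tp_size) * tp_size
        # For very small hidden sizes the break-even point is so large that
        # SP never pays off; signal that by returning None (assumed cutoff).
        if min_tokens > 32768:
            return None
        return min_tokens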
@@ -1456,6 +1481,36 @@ class VllmConfig:
                     "allreduce-rms fusion will be enabled for all num_tokens."
                 )
+            # Add the compile ranges for sequence parallelism
+            if compilation_config.pass_config.enable_sp:
+                pass_config = compilation_config.pass_config
+                # Calculate min_token_num if not explicitly provided.
+                # User override works regardless of hidden_size.
+                if pass_config.sp_min_token_num is None:
+                    from vllm.compilation.passes.fusion.sequence_parallelism import (
+                        get_sequence_parallelism_threshold,
+                    )
+
+                    tp_size = self.parallel_config.tensor_parallel_size
+                    hidden_size = self.model_config.get_hidden_size()
+                    element_size = self.model_config.dtype.itemsize
+                    pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
+                        hidden_size, tp_size, element_size
+                    )
+                min_token_num = pass_config.sp_min_token_num
+                max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+                if min_token_num is not None and (
+                    max_num_batched_tokens is not None
+                    and min_token_num < max_num_batched_tokens
+                    and min_token_num > 1
+                ):
+                    # Add split point at min_token_num - 1 to ensure SP applies
+                    # starting from min_token_num.
+                    # This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
+                    computed_compile_ranges_split_points.append(min_token_num - 1)
             if compilation_config.pass_config.fuse_rope_kvcache:
                 max_token_num = (
                     compilation_config.pass_config.rope_kvcache_fusion_max_token_num
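
A worked example of the SP split-point arithmetic above (values arbitrary):

    # With sp_min_token_num resolved to 2048 and max_num_batched_tokens 8192:
    min_token_num = 2048
    max_num_batched_tokens = 8192
    computed_compile_ranges_split_points: list[int] = []

    if 1 < min_token_num < max_num_batched_tokens:
        computed_compile_ranges_split_points.append(min_token_num - 1)

    # The split point 2047 yields the compile ranges [1, 2047] (SP disabled)
    # and [2048, 8192] (SP applied), matching the comment in the diff.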