[torch.compile] Sequence Parallelism threshold compile ranges (#28672)

Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Jason Li
2026-02-25 21:00:12 -08:00
committed by GitHub
parent 4171ff6dd9
commit 9d37941017
8 changed files with 524 additions and 32 deletions

34
tests/compile/conftest.py Normal file
View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from vllm.platforms.interface import DeviceCapability
@pytest.fixture
def mock_cuda_platform():
"""
Fixture that returns a factory for creating mocked CUDA platforms.
Usage:
def test_something(mock_cuda_platform):
with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
# test code
"""
@contextmanager
def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
mock_platform = MagicMock()
mock_platform.is_cuda.return_value = is_cuda
if capability is not None:
mock_platform.get_device_capability.return_value = DeviceCapability(
*capability
)
with patch("vllm.platforms.current_platform", mock_platform):
yield mock_platform
return _mock_platform

View File

@@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
run_model(full_compilation_config, model_name, **model_kwargs) run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges()) num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2] assert num_compile_ranges in [1, 2, 3]
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}") print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
print("Fusion results:") print("Fusion results:")
@@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# Now check the matches # Now check the matches
for match_name in matches_check: for match_name in matches_check:
num_ranges_activated = (
1 if match_name == "ar_rms_fusion" else num_compile_ranges
)
n_expected = tp_size * num_ranges_activated
log_matches = list(int(ms) for ms in log_matches_dict[match_name]) log_matches = list(int(ms) for ms in log_matches_dict[match_name])
# AR+RMS skips the largest range; SP skips the smallest.
# When both are enabled, AR+RMS activation count is
# model-dependent (hidden_size affects threshold), so derive
# from log data.
if (
match_name == "ar_rms_fusion"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
assert (
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
), (
f"Expected multiple of {tp_size} ar_rms log entries, "
f"found {len(log_matches)}"
)
num_ranges_activated = len(log_matches) // tp_size
elif (
match_name in ("ar_rms_fusion", "sequence_parallel")
and num_compile_ranges >= 2
):
num_ranges_activated = num_compile_ranges - 1
else:
num_ranges_activated = num_compile_ranges
n_expected = tp_size * num_ranges_activated
assert len(log_matches) == n_expected, ( assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} " f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}" f"(found {len(log_matches)}) in:\n {log_holder.text}"
@@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check: if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
# AR+rms+quant takes precedence over rms+quant if activated. # AR+rms+quant takes precedence over rms+quant if activated.
# That means we get full matching where ar+rms+quant was not activated, # That means we get full matching where ar+rms+quant was not
# and less where it was # activated, and less where it was (only the smallest range).
assert sum(m == expected_matches for m in log_matches) == tp_size * ( assert sum(m == expected_matches for m in log_matches) == tp_size * (
num_ranges_activated - 1 num_ranges_activated - 1
), "Expecting full rms+quant fusion where ar+rms+quant not activated" ), "Expecting full rms+quant fusion where ar+rms+quant not activated"
@@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"Expecting at least {expected_matches - matches.ar_rms_fusion} " f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
f"where ar+rms+quant was activated" f"where ar+rms+quant was activated"
) )
elif (
match_name == "async_tp"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
# AsyncTP only finds patterns on ranges where SP ran.
n_sp_ranges = num_compile_ranges - 1
assert (
sum(m == expected_matches for m in log_matches)
== tp_size * n_sp_ranges
), (
f"Expecting {expected_matches} async_tp on "
f"{tp_size * n_sp_ranges} SP-range entries, "
f"found: {log_matches}"
)
assert sum(m == 0 for m in log_matches) == tp_size, (
f"Expecting 0 async_tp on {tp_size} small-range entries "
f"(no SP), found: {log_matches}"
)
elif (
match_name == "ar_rms_fusion"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
# SP consumes allreduce patterns first, so AR+RMS finds
# full matches only on the smallest range (no SP).
assert sum(m == expected_matches for m in log_matches) == tp_size, (
f"Expecting {expected_matches} ar_rms on "
f"{tp_size} small-range entries, found: {log_matches}"
)
assert sum(m == 0 for m in log_matches) == tp_size * (
num_ranges_activated - 1
), (
f"Expecting 0 ar_rms on "
f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}"
)
else: else:
expected_matches_list = [expected_matches] * n_expected expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, ( assert sorted(log_matches) == expected_matches_list, (
@@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"found: {sorted(log_matches)}" f"found: {sorted(log_matches)}"
) )
if match_name == "ar_rms_fusion": if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
log_matches = re.findall( log_matches = re.findall(
r"pass_manager.py:\d+] Skipping " r"pass_manager.py:\d+] Skipping "
r".*AllReduceFusionPass.* with compile range", r".*AllReduceFusionPass.* with compile range",
@@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"(found {len(log_matches)}) in:\n {log_holder.text}" f"(found {len(log_matches)}) in:\n {log_holder.text}"
) )
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
log_matches = re.findall(
r"pass_manager.py:\d+] Skipping "
r".*SequenceParallelismPass.* with compile range",
log_holder.text,
)
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
assert len(log_matches) == n_expected, (
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
return run return run

View File

@@ -66,6 +66,9 @@ def test_tp2_async_tp_fp8_fusions(
enable_qk_norm_rope_fusion=True, enable_qk_norm_rope_fusion=True,
enable_sp=True, enable_sp=True,
fuse_gemm_comms=True, fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
), ),
) )
@@ -123,6 +126,9 @@ def test_tp2_async_tp_fusions(
enable_qk_norm_rope_fusion=True, enable_qk_norm_rope_fusion=True,
enable_sp=True, enable_sp=True,
fuse_gemm_comms=True, fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
), ),
) )
@@ -141,3 +147,130 @@ def test_tp2_async_tp_fusions(
matches_check, matches_check,
tp_size=2, tp_size=2,
) )
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp8, llama4_scout_fp8],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fp8_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
monkeypatch,
):
matches = matches_fn(n_layers)
if is_blackwell():
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
fuse_norm_quant=True,
fuse_act_quant=True,
fuse_attn_quant=True,
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"rms_quant_fusion",
"act_quant_fusion",
"norm_rope_fusion",
"attn_quant_fusion",
"ar_rms_fusion",
"sequence_parallel",
"async_tp",
]
run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
):
matches = matches_fn(n_layers)
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"norm_rope_fusion",
"ar_rms_fusion",
"sequence_parallel",
"async_tp",
]
run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)

View File

@@ -421,6 +421,7 @@ def test_cudagraph_sizes_post_init(
fuse_norm_quant=True, fuse_norm_quant=True,
fuse_act_quant=True, fuse_act_quant=True,
eliminate_noops=True, eliminate_noops=True,
sp_min_token_num=512 if enable_sp else None,
), ),
cudagraph_mode=cudagraph_mode, cudagraph_mode=cudagraph_mode,
) )

View File

@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.compilation.passes.fusion.sequence_parallelism import (
SP_MIN_HIDDEN_SIZE,
SP_MIN_PER_GPU_SIZE_MB,
get_sequence_parallelism_threshold,
)
class TestGetSequenceParallelismThreshold:
"""Tests for get_sequence_parallelism_threshold function."""
def test_non_cuda_returns_none(self, mock_cuda_platform):
"""Non-CUDA platforms should return None."""
with mock_cuda_platform(is_cuda=False):
result = get_sequence_parallelism_threshold(
hidden_size=8192, tp_size=2, element_size=2
)
assert result is None
def test_unsupported_device_capability_returns_none(self, mock_cuda_platform):
"""Unsupported device capabilities (e.g., sm80) should return None."""
with mock_cuda_platform(capability=(8, 0)):
result = get_sequence_parallelism_threshold(
hidden_size=8192, tp_size=2, element_size=2
)
assert result is None
def test_small_hidden_size_returns_none(self, mock_cuda_platform):
"""H100 with hidden_size below threshold should return None."""
with mock_cuda_platform(capability=(9, 0)):
result = get_sequence_parallelism_threshold(
hidden_size=4096,
tp_size=2,
element_size=2, # 4096 < 8192
)
assert result is None
def test_h100_large_model_returns_threshold(self, mock_cuda_platform):
"""H100 with large enough hidden_size should return calculated threshold."""
with mock_cuda_platform(capability=(9, 0)):
hidden_size = 8192
tp_size = 2
element_size = 2 # float16/bfloat16
result = get_sequence_parallelism_threshold(
hidden_size=hidden_size,
tp_size=tp_size,
element_size=element_size,
)
# Verify calculation: (8 * 2 * 1024 * 1024) // (8192 * 2) = 1024
MiB = 1024 * 1024
expected = int(
(SP_MIN_PER_GPU_SIZE_MB[90] * tp_size * MiB)
// (hidden_size * element_size)
)
assert result == expected
assert result == 1024
@pytest.mark.parametrize(
"hidden_size,tp_size,element_size,expected",
[
# Boundary: exactly at min hidden size threshold, tp_size=1
# (8 * 1 * 1024 * 1024) // (8192 * 2) = 512
(8192, 1, 2, 512),
# Larger hidden size reduces token threshold
# (8 * 1 * 1024 * 1024) // (16384 * 2) = 256
(16384, 1, 2, 256),
# Larger tp_size increases token threshold
# (8 * 4 * 1024 * 1024) // (8192 * 2) = 2048
(8192, 4, 2, 2048),
# Larger element_size (fp32) reduces token threshold
# (8 * 2 * 1024 * 1024) // (8192 * 4) = 512
(8192, 2, 4, 512),
],
)
def test_threshold_calculation_variations(
self, mock_cuda_platform, hidden_size, tp_size, element_size, expected
):
"""Test threshold calculation with various parameter combinations."""
with mock_cuda_platform(capability=(9, 0)):
result = get_sequence_parallelism_threshold(
hidden_size=hidden_size,
tp_size=tp_size,
element_size=element_size,
)
assert result == expected
def test_hidden_size_boundary(self, mock_cuda_platform):
"""Test behavior at the exact hidden_size boundary."""
with mock_cuda_platform(capability=(9, 0)):
# Just below threshold
result = get_sequence_parallelism_threshold(
hidden_size=SP_MIN_HIDDEN_SIZE[90] - 1,
tp_size=2,
element_size=2,
)
assert result is None
# Exactly at threshold
result = get_sequence_parallelism_threshold(
hidden_size=SP_MIN_HIDDEN_SIZE[90],
tp_size=2,
element_size=2,
)
assert result is not None

View File

@@ -27,6 +27,63 @@ from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNo
logger = init_logger(__name__) logger = init_logger(__name__)
# Min hidden size per device capability for sequence parallelism
# Only apply sequence parallelism for models with hidden_size >= threshold
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
90: 8192, # H100: only for models with hidden_size >= 8192
}
# Min size per GPU per device capability for sequence parallelism
# Total min size = min_per_gpu_size * tp_size
# This ensures the threshold scales appropriately with tensor parallelism
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
90: 8, # 8MB per GPU for H100
}
def get_sequence_parallelism_threshold(
hidden_size: int,
tp_size: int,
element_size: int,
) -> int | None:
"""
Calculate the minimum token threshold for applying sequence parallelism.
Returns None if sequence parallelism should not be applied based on model size.
Branching logic based on device capability:
- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
- If not, returns None (SP disabled for small models on this device)
- If yes, calculates threshold based on per-GPU size
Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
(hidden_size * element_size)
"""
from vllm.platforms import current_platform
if not current_platform.is_cuda():
return None
capability = current_platform.get_device_capability()
if capability is None:
return None
device_capability = capability.to_int()
# Check if device has configured thresholds
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
if min_hidden_size is None or min_per_gpu_size_mb is None:
return None
# Only apply sequence parallelism for models meeting the size threshold
if hidden_size < min_hidden_size:
return None
MiB = 1024 * 1024
min_size = min_per_gpu_size_mb * MiB * tp_size
return int(min_size // (hidden_size * element_size))
def get_first_out_wrapper( def get_first_out_wrapper(
fn: Callable[..., Sequence[torch.Tensor]], fn: Callable[..., Sequence[torch.Tensor]],
@@ -309,6 +366,23 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
def __init__(self, config: VllmConfig) -> None: def __init__(self, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
# Get min_token_num threshold
# Read min_token_num from config (calculated during config init)
self.min_token_num = None
if config.model_config is not None:
pass_config = config.compilation_config.pass_config
self.min_token_num = pass_config.sp_min_token_num
if self.min_token_num is not None:
# Take the min to avoid exceeding max_num_batched_tokens
max_batched = config.scheduler_config.max_num_batched_tokens
if max_batched is not None:
self.min_token_num = min(self.min_token_num, max_batched)
logger.debug_once(
f"Sequence parallelism min token threshold: {self.min_token_num}",
scope="global",
)
# Used to clean up redundant views created temporarily # Used to clean up redundant views created temporarily
# to circumvent residual shape change issues # to circumvent residual shape change issues
self.noop_cleanup = NoOpEliminationPass(config) self.noop_cleanup = NoOpEliminationPass(config)
@@ -339,29 +413,36 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool: def is_applicable_for_range(self, compile_range: Range) -> bool:
# When sequence parallelism is enabled, the residual tensor from RMSNorm """
# needs to be split along the sequence dimension. However, this dimension Determines if sequence parallelism should be applied for the given
# is symbolic during piecewise compilation, and splitting symbolic shapes compile range.
# is not supported.
# SP is only beneficial for larger batch sizes where the communication
# This pass is therefore only applied when the sequence dimension is overhead is amortized. For small batches, the overhead of splitting
# concrete: and gathering tensors across TP ranks outweighs the benefits.
# 1. In full-graph compilation mode (no Dynamo splitting ops are used).
# For this case we always pad num_tokens to be a multiple of Returns False (SP disabled) when:
# tensor_parallel_size, so there's no need to check shape % tp_size == 0. - Using piecewise compilation with non-concrete or TP-indivisible sizes
# 2. For specific shape provided during compilation (e.g., from - min_token_num is None (SP disabled for this device/config)
# `compile_sizes`), which must be divisible by the tensor-parallel - The compile range starts below the minimum token threshold
# size. """
# For piecewise compilation (not using inductor graph partition),
# we need concrete sizes that are divisible by TP for correct splitting
if ( if (
not self.compilation_config.splitting_ops not self.compilation_config.use_inductor_graph_partition
or self.compilation_config.use_inductor_graph_partition and self.compilation_config.splitting_ops
): ):
return True tp_size = get_tensor_model_parallel_world_size()
tp_size = get_tensor_model_parallel_world_size() if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
result: bool = (compile_range.is_single_size()) and ( return False
compile_range.end % tp_size == 0
) # min_token_num is None when SP is disabled for this device/config
return result # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
if self.min_token_num is None:
return False
# Only apply SP when batch size meets the minimum threshold
return compile_range.start >= self.min_token_num
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None: def __call__(self, graph: fx.Graph) -> None:

View File

@@ -118,7 +118,9 @@ class PassConfig:
eliminate_noops: bool = Field(default=True) eliminate_noops: bool = Field(default=True)
"""Eliminate no-op ops.""" """Eliminate no-op ops."""
enable_sp: bool = Field(default=None) enable_sp: bool = Field(default=None)
"""Enable sequence parallelism.""" """Enable sequence parallelism. Requires TP>1. Automatically disabled
if the model's hidden_size is too small for SP to be beneficial
(threshold is device-capability dependent)."""
fuse_gemm_comms: bool = Field(default=None) fuse_gemm_comms: bool = Field(default=None)
"""Enable async TP.""" """Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None) fuse_allreduce_rms: bool = Field(default=None)
@@ -155,6 +157,11 @@ class PassConfig:
8: 1, # 1MB 8: 1, # 1MB
}, },
}, where key is the device capability""" }, where key is the device capability"""
sp_min_token_num: int | None = None
"""The minimum number of tokens above which vllm should use
sequence parallelism. Specified as an integer token count.
Unspecified will fallback to default values which are compute
capability and world size dependent."""
# TODO(luka) better pass enabling system. # TODO(luka) better pass enabling system.

View File

@@ -853,8 +853,33 @@ class VllmConfig:
logger.warning("Sequence Parallelism requires TP>1, disabling") logger.warning("Sequence Parallelism requires TP>1, disabling")
self.compilation_config.pass_config.enable_sp = False self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False self.compilation_config.pass_config.fuse_gemm_comms = False
else:
# Compute SP threshold early; disable if None (model too
# small) before +rms_norm gets forced into custom_ops.
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold,
)
elif "-rms_norm" in self.compilation_config.custom_ops: tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size
)
if pass_config.sp_min_token_num is None:
logger.warning(
"Model hidden_size too small for the SP "
"threshold heuristic, disabling. To force SP, "
"set pass_config.sp_min_token_num manually."
)
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
if self.compilation_config.pass_config.enable_sp:
if "-rms_norm" in self.compilation_config.custom_ops:
logger.warning( logger.warning(
"RMS norm force disabled, sequence parallelism might break" "RMS norm force disabled, sequence parallelism might break"
) )
@@ -1456,6 +1481,36 @@ class VllmConfig:
"allreduce-rms fusion will be enabled for all num_tokens." "allreduce-rms fusion will be enabled for all num_tokens."
) )
# Add the compile ranges for sequence parallelism
if compilation_config.pass_config.enable_sp:
pass_config = compilation_config.pass_config
# Calculate min_token_num if not explicitly provided
# User override works regardless of hidden_size
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold,
)
tp_size = self.parallel_config.tensor_parallel_size
hidden_size = self.model_config.get_hidden_size()
element_size = self.model_config.dtype.itemsize
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
hidden_size, tp_size, element_size
)
min_token_num = pass_config.sp_min_token_num
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
if min_token_num is not None and (
max_num_batched_tokens is not None
and min_token_num < max_num_batched_tokens
and min_token_num > 1
):
# Add split point at min_token_num - 1 to ensure SP applies
# starting from min_token_num
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
computed_compile_ranges_split_points.append(min_token_num - 1)
if compilation_config.pass_config.fuse_rope_kvcache: if compilation_config.pass_config.fuse_rope_kvcache:
max_token_num = ( max_token_num = (
compilation_config.pass_config.rope_kvcache_fusion_max_token_num compilation_config.pass_config.rope_kvcache_fusion_max_token_num