From 9d379410179b649f4e7651940debc35c4ac7c0a5 Mon Sep 17 00:00:00 2001 From: Jason Li Date: Wed, 25 Feb 2026 21:00:12 -0800 Subject: [PATCH] [torch.compile] Sequence Parallelism threshold compile ranges (#28672) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: jasonlizhengjian Signed-off-by: Jason Li Co-authored-by: Claude Opus 4.6 Co-authored-by: Luka Govedič --- tests/compile/conftest.py | 34 +++++ tests/compile/fusions_e2e/conftest.py | 89 ++++++++++-- .../compile/fusions_e2e/test_tp2_async_tp.py | 133 ++++++++++++++++++ tests/compile/test_config.py | 1 + .../test_sequence_parallelism_threshold.py | 110 +++++++++++++++ .../passes/fusion/sequence_parallelism.py | 123 +++++++++++++--- vllm/config/compilation.py | 9 +- vllm/config/vllm.py | 57 +++++++- 8 files changed, 524 insertions(+), 32 deletions(-) create mode 100644 tests/compile/conftest.py create mode 100644 tests/compile/test_sequence_parallelism_threshold.py diff --git a/tests/compile/conftest.py b/tests/compile/conftest.py new file mode 100644 index 000000000..6aafac7bc --- /dev/null +++ b/tests/compile/conftest.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + +from vllm.platforms.interface import DeviceCapability + + +@pytest.fixture +def mock_cuda_platform(): + """ + Fixture that returns a factory for creating mocked CUDA platforms. + + Usage: + def test_something(mock_cuda_platform): + with mock_cuda_platform(is_cuda=True, capability=(9, 0)): + # test code + """ + + @contextmanager + def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None): + mock_platform = MagicMock() + mock_platform.is_cuda.return_value = is_cuda + if capability is not None: + mock_platform.get_device_capability.return_value = DeviceCapability( + *capability + ) + with patch("vllm.platforms.current_platform", mock_platform): + yield mock_platform + + return _mock_platform diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 1d9f6cda9..40b4de57f 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): run_model(full_compilation_config, model_name, **model_kwargs) num_compile_ranges = len(full_compilation_config.get_compile_ranges()) - assert num_compile_ranges in [1, 2] + assert num_compile_ranges in [1, 2, 3] print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}") print("Fusion results:") @@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): # Now check the matches for match_name in matches_check: - num_ranges_activated = ( - 1 if match_name == "ar_rms_fusion" else num_compile_ranges - ) - n_expected = tp_size * num_ranges_activated - log_matches = list(int(ms) for ms in log_matches_dict[match_name]) + + # AR+RMS skips the largest range; SP skips the smallest. + # When both are enabled, AR+RMS activation count is + # model-dependent (hidden_size affects threshold), so derive + # from log data. + if ( + match_name == "ar_rms_fusion" + and "sequence_parallel" in matches_check + and num_compile_ranges >= 2 + ): + assert ( + len(log_matches) >= tp_size and len(log_matches) % tp_size == 0 + ), ( + f"Expected multiple of {tp_size} ar_rms log entries, " + f"found {len(log_matches)}" + ) + num_ranges_activated = len(log_matches) // tp_size + elif ( + match_name in ("ar_rms_fusion", "sequence_parallel") + and num_compile_ranges >= 2 + ): + num_ranges_activated = num_compile_ranges - 1 + else: + num_ranges_activated = num_compile_ranges + + n_expected = tp_size * num_ranges_activated assert len(log_matches) == n_expected, ( f"Could not find {n_expected} {match_name} " f"(found {len(log_matches)}) in:\n {log_holder.text}" @@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check: # AR+rms+quant takes precedence over rms+quant if activated. - # That means we get full matching where ar+rms+quant was not activated, - # and less where it was + # That means we get full matching where ar+rms+quant was not + # activated, and less where it was (only the smallest range). assert sum(m == expected_matches for m in log_matches) == tp_size * ( num_ranges_activated - 1 ), "Expecting full rms+quant fusion where ar+rms+quant not activated" @@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): f"Expecting at least {expected_matches - matches.ar_rms_fusion} " f"where ar+rms+quant was activated" ) + elif ( + match_name == "async_tp" + and "sequence_parallel" in matches_check + and num_compile_ranges >= 2 + ): + # AsyncTP only finds patterns on ranges where SP ran. + n_sp_ranges = num_compile_ranges - 1 + assert ( + sum(m == expected_matches for m in log_matches) + == tp_size * n_sp_ranges + ), ( + f"Expecting {expected_matches} async_tp on " + f"{tp_size * n_sp_ranges} SP-range entries, " + f"found: {log_matches}" + ) + assert sum(m == 0 for m in log_matches) == tp_size, ( + f"Expecting 0 async_tp on {tp_size} small-range entries " + f"(no SP), found: {log_matches}" + ) + elif ( + match_name == "ar_rms_fusion" + and "sequence_parallel" in matches_check + and num_compile_ranges >= 2 + ): + # SP consumes allreduce patterns first, so AR+RMS finds + # full matches only on the smallest range (no SP). + assert sum(m == expected_matches for m in log_matches) == tp_size, ( + f"Expecting {expected_matches} ar_rms on " + f"{tp_size} small-range entries, found: {log_matches}" + ) + assert sum(m == 0 for m in log_matches) == tp_size * ( + num_ranges_activated - 1 + ), ( + f"Expecting 0 ar_rms on " + f"{tp_size * (num_ranges_activated - 1)} large-range " + f"entries (SP took precedence), found: {log_matches}" + ) else: expected_matches_list = [expected_matches] * n_expected assert sorted(log_matches) == expected_matches_list, ( @@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): f"found: {sorted(log_matches)}" ) - if match_name == "ar_rms_fusion": + if match_name == "ar_rms_fusion" and num_compile_ranges >= 2: log_matches = re.findall( r"pass_manager.py:\d+] Skipping " r".*AllReduceFusionPass.* with compile range", @@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn): f"(found {len(log_matches)}) in:\n {log_holder.text}" ) + if match_name == "sequence_parallel" and num_compile_ranges >= 2: + log_matches = re.findall( + r"pass_manager.py:\d+] Skipping " + r".*SequenceParallelismPass.* with compile range", + log_holder.text, + ) + + n_expected = tp_size * (num_compile_ranges - num_ranges_activated) + assert len(log_matches) == n_expected, ( + f'Could not find {n_expected} "Skipping SequenceParallelismPass" ' + f"(found {len(log_matches)}) in:\n {log_holder.text}" + ) + return run diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index 4769ca1e0..921839ea0 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -66,6 +66,9 @@ def test_tp2_async_tp_fp8_fusions( enable_qk_norm_rope_fusion=True, enable_sp=True, fuse_gemm_comms=True, + fuse_allreduce_rms=False, + # Override threshold for testing (models have small hidden_size) + sp_min_token_num=512, ), ) @@ -123,6 +126,9 @@ def test_tp2_async_tp_fusions( enable_qk_norm_rope_fusion=True, enable_sp=True, fuse_gemm_comms=True, + fuse_allreduce_rms=False, + # Override threshold for testing (models have small hidden_size) + sp_min_token_num=512, ), ) @@ -141,3 +147,130 @@ def test_tp2_async_tp_fusions( matches_check, tp_size=2, ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, matches_fn, model_kwargs, hf_overrides", + [llama3_8b_fp8, llama4_scout_fp8], +) +@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN]) +@pytest.mark.parametrize("n_layers", [4]) +@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_tp2_sp_ar_rms_fp8_fusions( + model_name: str, + matches_fn: Callable[[int], Matches], + model_kwargs: dict, + hf_overrides: Callable[[int], dict], + attn_backend: AttentionBackendCase, + n_layers: int, + custom_ops: str, + inductor_graph_partition: bool, + run_e2e_fusion_test, + monkeypatch, +): + matches = matches_fn(n_layers) + + if is_blackwell(): + # Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns + monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel") + + # Reduce size of model and skip weight loading time + model_kwargs["hf_overrides"] = hf_overrides(n_layers) + model_kwargs["load_format"] = "dummy" + model_kwargs["max_model_len"] = 1024 + + compilation_config = dict( + use_inductor_graph_partition=inductor_graph_partition, + custom_ops=custom_ops.split(","), + pass_config=PassConfig( + fuse_norm_quant=True, + fuse_act_quant=True, + fuse_attn_quant=True, + enable_qk_norm_rope_fusion=True, + enable_sp=True, + fuse_gemm_comms=True, + fuse_allreduce_rms=True, + # Override threshold for testing (models have small hidden_size) + sp_min_token_num=512, + ), + ) + + matches_check = [ + "rms_quant_fusion", + "act_quant_fusion", + "norm_rope_fusion", + "attn_quant_fusion", + "ar_rms_fusion", + "sequence_parallel", + "async_tp", + ] + + run_e2e_fusion_test( + model_name, + matches, + model_kwargs, + attn_backend, + compilation_config, + matches_check, + tp_size=2, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, matches_fn, model_kwargs, hf_overrides", + [llama3_8b, qwen3_a3b], +) +@pytest.mark.parametrize("attn_backend", [TRITON_ATTN]) +@pytest.mark.parametrize("n_layers", [4]) +@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +def test_tp2_sp_ar_rms_fusions( + model_name: str, + matches_fn: Callable[[int], Matches], + model_kwargs: dict, + hf_overrides: Callable[[int], dict], + attn_backend: AttentionBackendCase, + n_layers: int, + custom_ops: str, + inductor_graph_partition: bool, + run_e2e_fusion_test, +): + matches = matches_fn(n_layers) + + # Reduce size of model and skip weight loading time + model_kwargs["hf_overrides"] = hf_overrides(n_layers) + model_kwargs["load_format"] = "dummy" + model_kwargs["max_model_len"] = 1024 + + compilation_config = dict( + use_inductor_graph_partition=inductor_graph_partition, + custom_ops=custom_ops.split(","), + pass_config=PassConfig( + enable_qk_norm_rope_fusion=True, + enable_sp=True, + fuse_gemm_comms=True, + fuse_allreduce_rms=True, + # Override threshold for testing (models have small hidden_size) + sp_min_token_num=512, + ), + ) + + matches_check = [ + "norm_rope_fusion", + "ar_rms_fusion", + "sequence_parallel", + "async_tp", + ] + + run_e2e_fusion_test( + model_name, + matches, + model_kwargs, + attn_backend, + compilation_config, + matches_check, + tp_size=2, + ) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index eb2f0669e..3ba70b6aa 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -421,6 +421,7 @@ def test_cudagraph_sizes_post_init( fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True, + sp_min_token_num=512 if enable_sp else None, ), cudagraph_mode=cudagraph_mode, ) diff --git a/tests/compile/test_sequence_parallelism_threshold.py b/tests/compile/test_sequence_parallelism_threshold.py new file mode 100644 index 000000000..42e374cd9 --- /dev/null +++ b/tests/compile/test_sequence_parallelism_threshold.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.compilation.passes.fusion.sequence_parallelism import ( + SP_MIN_HIDDEN_SIZE, + SP_MIN_PER_GPU_SIZE_MB, + get_sequence_parallelism_threshold, +) + + +class TestGetSequenceParallelismThreshold: + """Tests for get_sequence_parallelism_threshold function.""" + + def test_non_cuda_returns_none(self, mock_cuda_platform): + """Non-CUDA platforms should return None.""" + with mock_cuda_platform(is_cuda=False): + result = get_sequence_parallelism_threshold( + hidden_size=8192, tp_size=2, element_size=2 + ) + assert result is None + + def test_unsupported_device_capability_returns_none(self, mock_cuda_platform): + """Unsupported device capabilities (e.g., sm80) should return None.""" + with mock_cuda_platform(capability=(8, 0)): + result = get_sequence_parallelism_threshold( + hidden_size=8192, tp_size=2, element_size=2 + ) + assert result is None + + def test_small_hidden_size_returns_none(self, mock_cuda_platform): + """H100 with hidden_size below threshold should return None.""" + with mock_cuda_platform(capability=(9, 0)): + result = get_sequence_parallelism_threshold( + hidden_size=4096, + tp_size=2, + element_size=2, # 4096 < 8192 + ) + assert result is None + + def test_h100_large_model_returns_threshold(self, mock_cuda_platform): + """H100 with large enough hidden_size should return calculated threshold.""" + with mock_cuda_platform(capability=(9, 0)): + hidden_size = 8192 + tp_size = 2 + element_size = 2 # float16/bfloat16 + + result = get_sequence_parallelism_threshold( + hidden_size=hidden_size, + tp_size=tp_size, + element_size=element_size, + ) + + # Verify calculation: (8 * 2 * 1024 * 1024) // (8192 * 2) = 1024 + MiB = 1024 * 1024 + expected = int( + (SP_MIN_PER_GPU_SIZE_MB[90] * tp_size * MiB) + // (hidden_size * element_size) + ) + assert result == expected + assert result == 1024 + + @pytest.mark.parametrize( + "hidden_size,tp_size,element_size,expected", + [ + # Boundary: exactly at min hidden size threshold, tp_size=1 + # (8 * 1 * 1024 * 1024) // (8192 * 2) = 512 + (8192, 1, 2, 512), + # Larger hidden size reduces token threshold + # (8 * 1 * 1024 * 1024) // (16384 * 2) = 256 + (16384, 1, 2, 256), + # Larger tp_size increases token threshold + # (8 * 4 * 1024 * 1024) // (8192 * 2) = 2048 + (8192, 4, 2, 2048), + # Larger element_size (fp32) reduces token threshold + # (8 * 2 * 1024 * 1024) // (8192 * 4) = 512 + (8192, 2, 4, 512), + ], + ) + def test_threshold_calculation_variations( + self, mock_cuda_platform, hidden_size, tp_size, element_size, expected + ): + """Test threshold calculation with various parameter combinations.""" + with mock_cuda_platform(capability=(9, 0)): + result = get_sequence_parallelism_threshold( + hidden_size=hidden_size, + tp_size=tp_size, + element_size=element_size, + ) + assert result == expected + + def test_hidden_size_boundary(self, mock_cuda_platform): + """Test behavior at the exact hidden_size boundary.""" + with mock_cuda_platform(capability=(9, 0)): + # Just below threshold + result = get_sequence_parallelism_threshold( + hidden_size=SP_MIN_HIDDEN_SIZE[90] - 1, + tp_size=2, + element_size=2, + ) + assert result is None + + # Exactly at threshold + result = get_sequence_parallelism_threshold( + hidden_size=SP_MIN_HIDDEN_SIZE[90], + tp_size=2, + element_size=2, + ) + assert result is not None diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py index 5fb932d72..63de85932 100644 --- a/vllm/compilation/passes/fusion/sequence_parallelism.py +++ b/vllm/compilation/passes/fusion/sequence_parallelism.py @@ -27,6 +27,63 @@ from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNo logger = init_logger(__name__) +# Min hidden size per device capability for sequence parallelism +# Only apply sequence parallelism for models with hidden_size >= threshold +SP_MIN_HIDDEN_SIZE: dict[int, int] = { + 90: 8192, # H100: only for models with hidden_size >= 8192 +} + +# Min size per GPU per device capability for sequence parallelism +# Total min size = min_per_gpu_size * tp_size +# This ensures the threshold scales appropriately with tensor parallelism +SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = { + 90: 8, # 8MB per GPU for H100 +} + + +def get_sequence_parallelism_threshold( + hidden_size: int, + tp_size: int, + element_size: int, +) -> int | None: + """ + Calculate the minimum token threshold for applying sequence parallelism. + + Returns None if sequence parallelism should not be applied based on model size. + + Branching logic based on device capability: + - Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability] + - If not, returns None (SP disabled for small models on this device) + - If yes, calculates threshold based on per-GPU size + + Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) // + (hidden_size * element_size) + """ + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + return None + + capability = current_platform.get_device_capability() + if capability is None: + return None + device_capability = capability.to_int() + + # Check if device has configured thresholds + min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability) + min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability) + + if min_hidden_size is None or min_per_gpu_size_mb is None: + return None + + # Only apply sequence parallelism for models meeting the size threshold + if hidden_size < min_hidden_size: + return None + + MiB = 1024 * 1024 + min_size = min_per_gpu_size_mb * MiB * tp_size + return int(min_size // (hidden_size * element_size)) + def get_first_out_wrapper( fn: Callable[..., Sequence[torch.Tensor]], @@ -309,6 +366,23 @@ class SequenceParallelismPass(VllmPatternMatcherPass): def __init__(self, config: VllmConfig) -> None: super().__init__(config) + # Get min_token_num threshold + # Read min_token_num from config (calculated during config init) + self.min_token_num = None + if config.model_config is not None: + pass_config = config.compilation_config.pass_config + self.min_token_num = pass_config.sp_min_token_num + + if self.min_token_num is not None: + # Take the min to avoid exceeding max_num_batched_tokens + max_batched = config.scheduler_config.max_num_batched_tokens + if max_batched is not None: + self.min_token_num = min(self.min_token_num, max_batched) + logger.debug_once( + f"Sequence parallelism min token threshold: {self.min_token_num}", + scope="global", + ) + # Used to clean up redundant views created temporarily # to circumvent residual shape change issues self.noop_cleanup = NoOpEliminationPass(config) @@ -339,29 +413,36 @@ class SequenceParallelismPass(VllmPatternMatcherPass): self.dump_patterns(config, self.patterns) def is_applicable_for_range(self, compile_range: Range) -> bool: - # When sequence parallelism is enabled, the residual tensor from RMSNorm - # needs to be split along the sequence dimension. However, this dimension - # is symbolic during piecewise compilation, and splitting symbolic shapes - # is not supported. - # - # This pass is therefore only applied when the sequence dimension is - # concrete: - # 1. In full-graph compilation mode (no Dynamo splitting ops are used). - # For this case we always pad num_tokens to be a multiple of - # tensor_parallel_size, so there's no need to check shape % tp_size == 0. - # 2. For specific shape provided during compilation (e.g., from - # `compile_sizes`), which must be divisible by the tensor-parallel - # size. + """ + Determines if sequence parallelism should be applied for the given + compile range. + + SP is only beneficial for larger batch sizes where the communication + overhead is amortized. For small batches, the overhead of splitting + and gathering tensors across TP ranks outweighs the benefits. + + Returns False (SP disabled) when: + - Using piecewise compilation with non-concrete or TP-indivisible sizes + - min_token_num is None (SP disabled for this device/config) + - The compile range starts below the minimum token threshold + """ + # For piecewise compilation (not using inductor graph partition), + # we need concrete sizes that are divisible by TP for correct splitting if ( - not self.compilation_config.splitting_ops - or self.compilation_config.use_inductor_graph_partition + not self.compilation_config.use_inductor_graph_partition + and self.compilation_config.splitting_ops ): - return True - tp_size = get_tensor_model_parallel_world_size() - result: bool = (compile_range.is_single_size()) and ( - compile_range.end % tp_size == 0 - ) - return result + tp_size = get_tensor_model_parallel_world_size() + if not compile_range.is_single_size() or compile_range.end % tp_size != 0: + return False + + # min_token_num is None when SP is disabled for this device/config + # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size) + if self.min_token_num is None: + return False + + # Only apply SP when batch size meets the minimum threshold + return compile_range.start >= self.min_token_num @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ab6f3da06..d22e9a96e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -118,7 +118,9 @@ class PassConfig: eliminate_noops: bool = Field(default=True) """Eliminate no-op ops.""" enable_sp: bool = Field(default=None) - """Enable sequence parallelism.""" + """Enable sequence parallelism. Requires TP>1. Automatically disabled + if the model's hidden_size is too small for SP to be beneficial + (threshold is device-capability dependent).""" fuse_gemm_comms: bool = Field(default=None) """Enable async TP.""" fuse_allreduce_rms: bool = Field(default=None) @@ -155,6 +157,11 @@ class PassConfig: 8: 1, # 1MB }, }, where key is the device capability""" + sp_min_token_num: int | None = None + """The minimum number of tokens above which vllm should use + sequence parallelism. Specified as an integer token count. + Unspecified will fallback to default values which are compute + capability and world size dependent.""" # TODO(luka) better pass enabling system. diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ef71a05d3..fba3c64a9 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -853,8 +853,33 @@ class VllmConfig: logger.warning("Sequence Parallelism requires TP>1, disabling") self.compilation_config.pass_config.enable_sp = False self.compilation_config.pass_config.fuse_gemm_comms = False + else: + # Compute SP threshold early; disable if None (model too + # small) before +rms_norm gets forced into custom_ops. + pass_config = self.compilation_config.pass_config + if pass_config.sp_min_token_num is None: + from vllm.compilation.passes.fusion.sequence_parallelism import ( + get_sequence_parallelism_threshold, + ) - elif "-rms_norm" in self.compilation_config.custom_ops: + tp_size = self.parallel_config.tensor_parallel_size + hidden_size = self.model_config.get_hidden_size() + element_size = self.model_config.dtype.itemsize + pass_config.sp_min_token_num = get_sequence_parallelism_threshold( + hidden_size, tp_size, element_size + ) + + if pass_config.sp_min_token_num is None: + logger.warning( + "Model hidden_size too small for the SP " + "threshold heuristic, disabling. To force SP, " + "set pass_config.sp_min_token_num manually." + ) + self.compilation_config.pass_config.enable_sp = False + self.compilation_config.pass_config.fuse_gemm_comms = False + + if self.compilation_config.pass_config.enable_sp: + if "-rms_norm" in self.compilation_config.custom_ops: logger.warning( "RMS norm force disabled, sequence parallelism might break" ) @@ -1456,6 +1481,36 @@ class VllmConfig: "allreduce-rms fusion will be enabled for all num_tokens." ) + # Add the compile ranges for sequence parallelism + if compilation_config.pass_config.enable_sp: + pass_config = compilation_config.pass_config + + # Calculate min_token_num if not explicitly provided + # User override works regardless of hidden_size + if pass_config.sp_min_token_num is None: + from vllm.compilation.passes.fusion.sequence_parallelism import ( + get_sequence_parallelism_threshold, + ) + + tp_size = self.parallel_config.tensor_parallel_size + hidden_size = self.model_config.get_hidden_size() + element_size = self.model_config.dtype.itemsize + pass_config.sp_min_token_num = get_sequence_parallelism_threshold( + hidden_size, tp_size, element_size + ) + + min_token_num = pass_config.sp_min_token_num + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if min_token_num is not None and ( + max_num_batched_tokens is not None + and min_token_num < max_num_batched_tokens + and min_token_num > 1 + ): + # Add split point at min_token_num - 1 to ensure SP applies + # starting from min_token_num + # This creates ranges: [1, min-1] (no SP), [min, max] (SP applies) + computed_compile_ranges_split_points.append(min_token_num - 1) + if compilation_config.pass_config.fuse_rope_kvcache: max_token_num = ( compilation_config.pass_config.rope_kvcache_fusion_max_token_num