[torch.compile] Sequence Parallelism threshold compile ranges (#28672)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com> Signed-off-by: Jason Li <jasonlizhengjian@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
34
tests/compile/conftest.py
Normal file
34
tests/compile/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cuda_platform():
|
||||
"""
|
||||
Fixture that returns a factory for creating mocked CUDA platforms.
|
||||
|
||||
Usage:
|
||||
def test_something(mock_cuda_platform):
|
||||
with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
|
||||
# test code
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
|
||||
mock_platform = MagicMock()
|
||||
mock_platform.is_cuda.return_value = is_cuda
|
||||
if capability is not None:
|
||||
mock_platform.get_device_capability.return_value = DeviceCapability(
|
||||
*capability
|
||||
)
|
||||
with patch("vllm.platforms.current_platform", mock_platform):
|
||||
yield mock_platform
|
||||
|
||||
return _mock_platform
|
||||
@@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
|
||||
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
||||
assert num_compile_ranges in [1, 2]
|
||||
assert num_compile_ranges in [1, 2, 3]
|
||||
|
||||
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
|
||||
print("Fusion results:")
|
||||
@@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
# Now check the matches
|
||||
for match_name in matches_check:
|
||||
num_ranges_activated = (
|
||||
1 if match_name == "ar_rms_fusion" else num_compile_ranges
|
||||
)
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
|
||||
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
|
||||
|
||||
# AR+RMS skips the largest range; SP skips the smallest.
|
||||
# When both are enabled, AR+RMS activation count is
|
||||
# model-dependent (hidden_size affects threshold), so derive
|
||||
# from log data.
|
||||
if (
|
||||
match_name == "ar_rms_fusion"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
assert (
|
||||
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
|
||||
), (
|
||||
f"Expected multiple of {tp_size} ar_rms log entries, "
|
||||
f"found {len(log_matches)}"
|
||||
)
|
||||
num_ranges_activated = len(log_matches) // tp_size
|
||||
elif (
|
||||
match_name in ("ar_rms_fusion", "sequence_parallel")
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
num_ranges_activated = num_compile_ranges - 1
|
||||
else:
|
||||
num_ranges_activated = num_compile_ranges
|
||||
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
@@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
|
||||
# AR+rms+quant takes precedence over rms+quant if activated.
|
||||
# That means we get full matching where ar+rms+quant was not activated,
|
||||
# and less where it was
|
||||
# That means we get full matching where ar+rms+quant was not
|
||||
# activated, and less where it was (only the smallest range).
|
||||
assert sum(m == expected_matches for m in log_matches) == tp_size * (
|
||||
num_ranges_activated - 1
|
||||
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
|
||||
@@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
|
||||
f"where ar+rms+quant was activated"
|
||||
)
|
||||
elif (
|
||||
match_name == "async_tp"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
# AsyncTP only finds patterns on ranges where SP ran.
|
||||
n_sp_ranges = num_compile_ranges - 1
|
||||
assert (
|
||||
sum(m == expected_matches for m in log_matches)
|
||||
== tp_size * n_sp_ranges
|
||||
), (
|
||||
f"Expecting {expected_matches} async_tp on "
|
||||
f"{tp_size * n_sp_ranges} SP-range entries, "
|
||||
f"found: {log_matches}"
|
||||
)
|
||||
assert sum(m == 0 for m in log_matches) == tp_size, (
|
||||
f"Expecting 0 async_tp on {tp_size} small-range entries "
|
||||
f"(no SP), found: {log_matches}"
|
||||
)
|
||||
elif (
|
||||
match_name == "ar_rms_fusion"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
# SP consumes allreduce patterns first, so AR+RMS finds
|
||||
# full matches only on the smallest range (no SP).
|
||||
assert sum(m == expected_matches for m in log_matches) == tp_size, (
|
||||
f"Expecting {expected_matches} ar_rms on "
|
||||
f"{tp_size} small-range entries, found: {log_matches}"
|
||||
)
|
||||
assert sum(m == 0 for m in log_matches) == tp_size * (
|
||||
num_ranges_activated - 1
|
||||
), (
|
||||
f"Expecting 0 ar_rms on "
|
||||
f"{tp_size * (num_ranges_activated - 1)} large-range "
|
||||
f"entries (SP took precedence), found: {log_matches}"
|
||||
)
|
||||
else:
|
||||
expected_matches_list = [expected_matches] * n_expected
|
||||
assert sorted(log_matches) == expected_matches_list, (
|
||||
@@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"found: {sorted(log_matches)}"
|
||||
)
|
||||
|
||||
if match_name == "ar_rms_fusion":
|
||||
if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
|
||||
log_matches = re.findall(
|
||||
r"pass_manager.py:\d+] Skipping "
|
||||
r".*AllReduceFusionPass.* with compile range",
|
||||
@@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
|
||||
log_matches = re.findall(
|
||||
r"pass_manager.py:\d+] Skipping "
|
||||
r".*SequenceParallelismPass.* with compile range",
|
||||
log_holder.text,
|
||||
)
|
||||
|
||||
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
|
||||
assert len(log_matches) == n_expected, (
|
||||
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
@@ -66,6 +66,9 @@ def test_tp2_async_tp_fp8_fusions(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=False,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -123,6 +126,9 @@ def test_tp2_async_tp_fusions(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=False,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -141,3 +147,130 @@ def test_tp2_async_tp_fusions(
|
||||
matches_check,
|
||||
tp_size=2,
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides",
|
||||
[llama3_8b_fp8, llama4_scout_fp8],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [4])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
def test_tp2_sp_ar_rms_fp8_fusions(
|
||||
model_name: str,
|
||||
matches_fn: Callable[[int], Matches],
|
||||
model_kwargs: dict,
|
||||
hf_overrides: Callable[[int], dict],
|
||||
attn_backend: AttentionBackendCase,
|
||||
n_layers: int,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
run_e2e_fusion_test,
|
||||
monkeypatch,
|
||||
):
|
||||
matches = matches_fn(n_layers)
|
||||
|
||||
if is_blackwell():
|
||||
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
|
||||
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
|
||||
|
||||
# Reduce size of model and skip weight loading time
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
custom_ops=custom_ops.split(","),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
fuse_attn_quant=True,
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=True,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
matches_check = [
|
||||
"rms_quant_fusion",
|
||||
"act_quant_fusion",
|
||||
"norm_rope_fusion",
|
||||
"attn_quant_fusion",
|
||||
"ar_rms_fusion",
|
||||
"sequence_parallel",
|
||||
"async_tp",
|
||||
]
|
||||
|
||||
run_e2e_fusion_test(
|
||||
model_name,
|
||||
matches,
|
||||
model_kwargs,
|
||||
attn_backend,
|
||||
compilation_config,
|
||||
matches_check,
|
||||
tp_size=2,
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides",
|
||||
[llama3_8b, qwen3_a3b],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [4])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
def test_tp2_sp_ar_rms_fusions(
|
||||
model_name: str,
|
||||
matches_fn: Callable[[int], Matches],
|
||||
model_kwargs: dict,
|
||||
hf_overrides: Callable[[int], dict],
|
||||
attn_backend: AttentionBackendCase,
|
||||
n_layers: int,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
run_e2e_fusion_test,
|
||||
):
|
||||
matches = matches_fn(n_layers)
|
||||
|
||||
# Reduce size of model and skip weight loading time
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
custom_ops=custom_ops.split(","),
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=True,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
matches_check = [
|
||||
"norm_rope_fusion",
|
||||
"ar_rms_fusion",
|
||||
"sequence_parallel",
|
||||
"async_tp",
|
||||
]
|
||||
|
||||
run_e2e_fusion_test(
|
||||
model_name,
|
||||
matches,
|
||||
model_kwargs,
|
||||
attn_backend,
|
||||
compilation_config,
|
||||
matches_check,
|
||||
tp_size=2,
|
||||
)
|
||||
|
||||
@@ -421,6 +421,7 @@ def test_cudagraph_sizes_post_init(
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
eliminate_noops=True,
|
||||
sp_min_token_num=512 if enable_sp else None,
|
||||
),
|
||||
cudagraph_mode=cudagraph_mode,
|
||||
)
|
||||
|
||||
110
tests/compile/test_sequence_parallelism_threshold.py
Normal file
110
tests/compile/test_sequence_parallelism_threshold.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||
SP_MIN_HIDDEN_SIZE,
|
||||
SP_MIN_PER_GPU_SIZE_MB,
|
||||
get_sequence_parallelism_threshold,
|
||||
)
|
||||
|
||||
|
||||
class TestGetSequenceParallelismThreshold:
|
||||
"""Tests for get_sequence_parallelism_threshold function."""
|
||||
|
||||
def test_non_cuda_returns_none(self, mock_cuda_platform):
|
||||
"""Non-CUDA platforms should return None."""
|
||||
with mock_cuda_platform(is_cuda=False):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=8192, tp_size=2, element_size=2
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_unsupported_device_capability_returns_none(self, mock_cuda_platform):
|
||||
"""Unsupported device capabilities (e.g., sm80) should return None."""
|
||||
with mock_cuda_platform(capability=(8, 0)):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=8192, tp_size=2, element_size=2
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_small_hidden_size_returns_none(self, mock_cuda_platform):
|
||||
"""H100 with hidden_size below threshold should return None."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=4096,
|
||||
tp_size=2,
|
||||
element_size=2, # 4096 < 8192
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_h100_large_model_returns_threshold(self, mock_cuda_platform):
|
||||
"""H100 with large enough hidden_size should return calculated threshold."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
hidden_size = 8192
|
||||
tp_size = 2
|
||||
element_size = 2 # float16/bfloat16
|
||||
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=hidden_size,
|
||||
tp_size=tp_size,
|
||||
element_size=element_size,
|
||||
)
|
||||
|
||||
# Verify calculation: (8 * 2 * 1024 * 1024) // (8192 * 2) = 1024
|
||||
MiB = 1024 * 1024
|
||||
expected = int(
|
||||
(SP_MIN_PER_GPU_SIZE_MB[90] * tp_size * MiB)
|
||||
// (hidden_size * element_size)
|
||||
)
|
||||
assert result == expected
|
||||
assert result == 1024
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hidden_size,tp_size,element_size,expected",
|
||||
[
|
||||
# Boundary: exactly at min hidden size threshold, tp_size=1
|
||||
# (8 * 1 * 1024 * 1024) // (8192 * 2) = 512
|
||||
(8192, 1, 2, 512),
|
||||
# Larger hidden size reduces token threshold
|
||||
# (8 * 1 * 1024 * 1024) // (16384 * 2) = 256
|
||||
(16384, 1, 2, 256),
|
||||
# Larger tp_size increases token threshold
|
||||
# (8 * 4 * 1024 * 1024) // (8192 * 2) = 2048
|
||||
(8192, 4, 2, 2048),
|
||||
# Larger element_size (fp32) reduces token threshold
|
||||
# (8 * 2 * 1024 * 1024) // (8192 * 4) = 512
|
||||
(8192, 2, 4, 512),
|
||||
],
|
||||
)
|
||||
def test_threshold_calculation_variations(
|
||||
self, mock_cuda_platform, hidden_size, tp_size, element_size, expected
|
||||
):
|
||||
"""Test threshold calculation with various parameter combinations."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=hidden_size,
|
||||
tp_size=tp_size,
|
||||
element_size=element_size,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_hidden_size_boundary(self, mock_cuda_platform):
|
||||
"""Test behavior at the exact hidden_size boundary."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
# Just below threshold
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=SP_MIN_HIDDEN_SIZE[90] - 1,
|
||||
tp_size=2,
|
||||
element_size=2,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Exactly at threshold
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=SP_MIN_HIDDEN_SIZE[90],
|
||||
tp_size=2,
|
||||
element_size=2,
|
||||
)
|
||||
assert result is not None
|
||||
@@ -27,6 +27,63 @@ from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNo
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Min hidden size per device capability for sequence parallelism
|
||||
# Only apply sequence parallelism for models with hidden_size >= threshold
|
||||
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
|
||||
90: 8192, # H100: only for models with hidden_size >= 8192
|
||||
}
|
||||
|
||||
# Min size per GPU per device capability for sequence parallelism
|
||||
# Total min size = min_per_gpu_size * tp_size
|
||||
# This ensures the threshold scales appropriately with tensor parallelism
|
||||
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
|
||||
90: 8, # 8MB per GPU for H100
|
||||
}
|
||||
|
||||
|
||||
def get_sequence_parallelism_threshold(
|
||||
hidden_size: int,
|
||||
tp_size: int,
|
||||
element_size: int,
|
||||
) -> int | None:
|
||||
"""
|
||||
Calculate the minimum token threshold for applying sequence parallelism.
|
||||
|
||||
Returns None if sequence parallelism should not be applied based on model size.
|
||||
|
||||
Branching logic based on device capability:
|
||||
- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
|
||||
- If not, returns None (SP disabled for small models on this device)
|
||||
- If yes, calculates threshold based on per-GPU size
|
||||
|
||||
Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
|
||||
(hidden_size * element_size)
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return None
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability is None:
|
||||
return None
|
||||
device_capability = capability.to_int()
|
||||
|
||||
# Check if device has configured thresholds
|
||||
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
|
||||
min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
|
||||
|
||||
if min_hidden_size is None or min_per_gpu_size_mb is None:
|
||||
return None
|
||||
|
||||
# Only apply sequence parallelism for models meeting the size threshold
|
||||
if hidden_size < min_hidden_size:
|
||||
return None
|
||||
|
||||
MiB = 1024 * 1024
|
||||
min_size = min_per_gpu_size_mb * MiB * tp_size
|
||||
return int(min_size // (hidden_size * element_size))
|
||||
|
||||
|
||||
def get_first_out_wrapper(
|
||||
fn: Callable[..., Sequence[torch.Tensor]],
|
||||
@@ -309,6 +366,23 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Get min_token_num threshold
|
||||
# Read min_token_num from config (calculated during config init)
|
||||
self.min_token_num = None
|
||||
if config.model_config is not None:
|
||||
pass_config = config.compilation_config.pass_config
|
||||
self.min_token_num = pass_config.sp_min_token_num
|
||||
|
||||
if self.min_token_num is not None:
|
||||
# Take the min to avoid exceeding max_num_batched_tokens
|
||||
max_batched = config.scheduler_config.max_num_batched_tokens
|
||||
if max_batched is not None:
|
||||
self.min_token_num = min(self.min_token_num, max_batched)
|
||||
logger.debug_once(
|
||||
f"Sequence parallelism min token threshold: {self.min_token_num}",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
# Used to clean up redundant views created temporarily
|
||||
# to circumvent residual shape change issues
|
||||
self.noop_cleanup = NoOpEliminationPass(config)
|
||||
@@ -339,29 +413,36 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# When sequence parallelism is enabled, the residual tensor from RMSNorm
|
||||
# needs to be split along the sequence dimension. However, this dimension
|
||||
# is symbolic during piecewise compilation, and splitting symbolic shapes
|
||||
# is not supported.
|
||||
#
|
||||
# This pass is therefore only applied when the sequence dimension is
|
||||
# concrete:
|
||||
# 1. In full-graph compilation mode (no Dynamo splitting ops are used).
|
||||
# For this case we always pad num_tokens to be a multiple of
|
||||
# tensor_parallel_size, so there's no need to check shape % tp_size == 0.
|
||||
# 2. For specific shape provided during compilation (e.g., from
|
||||
# `compile_sizes`), which must be divisible by the tensor-parallel
|
||||
# size.
|
||||
"""
|
||||
Determines if sequence parallelism should be applied for the given
|
||||
compile range.
|
||||
|
||||
SP is only beneficial for larger batch sizes where the communication
|
||||
overhead is amortized. For small batches, the overhead of splitting
|
||||
and gathering tensors across TP ranks outweighs the benefits.
|
||||
|
||||
Returns False (SP disabled) when:
|
||||
- Using piecewise compilation with non-concrete or TP-indivisible sizes
|
||||
- min_token_num is None (SP disabled for this device/config)
|
||||
- The compile range starts below the minimum token threshold
|
||||
"""
|
||||
# For piecewise compilation (not using inductor graph partition),
|
||||
# we need concrete sizes that are divisible by TP for correct splitting
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
not self.compilation_config.use_inductor_graph_partition
|
||||
and self.compilation_config.splitting_ops
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
result: bool = (compile_range.is_single_size()) and (
|
||||
compile_range.end % tp_size == 0
|
||||
)
|
||||
return result
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
|
||||
return False
|
||||
|
||||
# min_token_num is None when SP is disabled for this device/config
|
||||
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
|
||||
if self.min_token_num is None:
|
||||
return False
|
||||
|
||||
# Only apply SP when batch size meets the minimum threshold
|
||||
return compile_range.start >= self.min_token_num
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
|
||||
@@ -118,7 +118,9 @@ class PassConfig:
|
||||
eliminate_noops: bool = Field(default=True)
|
||||
"""Eliminate no-op ops."""
|
||||
enable_sp: bool = Field(default=None)
|
||||
"""Enable sequence parallelism."""
|
||||
"""Enable sequence parallelism. Requires TP>1. Automatically disabled
|
||||
if the model's hidden_size is too small for SP to be beneficial
|
||||
(threshold is device-capability dependent)."""
|
||||
fuse_gemm_comms: bool = Field(default=None)
|
||||
"""Enable async TP."""
|
||||
fuse_allreduce_rms: bool = Field(default=None)
|
||||
@@ -155,6 +157,11 @@ class PassConfig:
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}, where key is the device capability"""
|
||||
sp_min_token_num: int | None = None
|
||||
"""The minimum number of tokens above which vllm should use
|
||||
sequence parallelism. Specified as an integer token count.
|
||||
Unspecified will fallback to default values which are compute
|
||||
capability and world size dependent."""
|
||||
|
||||
# TODO(luka) better pass enabling system.
|
||||
|
||||
|
||||
@@ -853,8 +853,33 @@ class VllmConfig:
|
||||
logger.warning("Sequence Parallelism requires TP>1, disabling")
|
||||
self.compilation_config.pass_config.enable_sp = False
|
||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||
else:
|
||||
# Compute SP threshold early; disable if None (model too
|
||||
# small) before +rms_norm gets forced into custom_ops.
|
||||
pass_config = self.compilation_config.pass_config
|
||||
if pass_config.sp_min_token_num is None:
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||
get_sequence_parallelism_threshold,
|
||||
)
|
||||
|
||||
elif "-rms_norm" in self.compilation_config.custom_ops:
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
hidden_size = self.model_config.get_hidden_size()
|
||||
element_size = self.model_config.dtype.itemsize
|
||||
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
|
||||
hidden_size, tp_size, element_size
|
||||
)
|
||||
|
||||
if pass_config.sp_min_token_num is None:
|
||||
logger.warning(
|
||||
"Model hidden_size too small for the SP "
|
||||
"threshold heuristic, disabling. To force SP, "
|
||||
"set pass_config.sp_min_token_num manually."
|
||||
)
|
||||
self.compilation_config.pass_config.enable_sp = False
|
||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||
|
||||
if self.compilation_config.pass_config.enable_sp:
|
||||
if "-rms_norm" in self.compilation_config.custom_ops:
|
||||
logger.warning(
|
||||
"RMS norm force disabled, sequence parallelism might break"
|
||||
)
|
||||
@@ -1456,6 +1481,36 @@ class VllmConfig:
|
||||
"allreduce-rms fusion will be enabled for all num_tokens."
|
||||
)
|
||||
|
||||
# Add the compile ranges for sequence parallelism
|
||||
if compilation_config.pass_config.enable_sp:
|
||||
pass_config = compilation_config.pass_config
|
||||
|
||||
# Calculate min_token_num if not explicitly provided
|
||||
# User override works regardless of hidden_size
|
||||
if pass_config.sp_min_token_num is None:
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||
get_sequence_parallelism_threshold,
|
||||
)
|
||||
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
hidden_size = self.model_config.get_hidden_size()
|
||||
element_size = self.model_config.dtype.itemsize
|
||||
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
|
||||
hidden_size, tp_size, element_size
|
||||
)
|
||||
|
||||
min_token_num = pass_config.sp_min_token_num
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
if min_token_num is not None and (
|
||||
max_num_batched_tokens is not None
|
||||
and min_token_num < max_num_batched_tokens
|
||||
and min_token_num > 1
|
||||
):
|
||||
# Add split point at min_token_num - 1 to ensure SP applies
|
||||
# starting from min_token_num
|
||||
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
|
||||
computed_compile_ranges_split_points.append(min_token_num - 1)
|
||||
|
||||
if compilation_config.pass_config.fuse_rope_kvcache:
|
||||
max_token_num = (
|
||||
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
|
||||
|
||||
Reference in New Issue
Block a user