[torch.compile] Sequence Parallelism threshold compile ranges (#28672)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com> Signed-off-by: Jason Li <jasonlizhengjian@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
34
tests/compile/conftest.py
Normal file
34
tests/compile/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.platforms.interface import DeviceCapability
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_cuda_platform():
    """
    Provide a context-manager factory that swaps in a fake platform object.

    The returned callable builds a ``MagicMock`` platform, configures it from
    its arguments, and patches ``vllm.platforms.current_platform`` with it for
    the duration of the ``with`` block. The mock itself is yielded so a test
    can inspect or further configure it.

    Factory arguments:
        is_cuda: value returned by the mock's ``is_cuda()``.
        capability: optional ``(major, minor)`` pair; when given,
            ``get_device_capability()`` returns the matching
            ``DeviceCapability``. When omitted, that method is left as an
            unconfigured ``MagicMock`` attribute.

    Usage:
        def test_something(mock_cuda_platform):
            with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
                # test code
    """

    @contextmanager
    def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
        fake_platform = MagicMock()
        fake_platform.is_cuda.return_value = is_cuda

        # Only pin the capability when the caller asked for a specific one.
        if capability is not None:
            device_capability = DeviceCapability(*capability)
            fake_platform.get_device_capability.return_value = device_capability

        # The patch is undone automatically when the caller's `with` exits.
        with patch("vllm.platforms.current_platform", fake_platform):
            yield fake_platform

    return _mock_platform
|
||||||
@@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
|||||||
run_model(full_compilation_config, model_name, **model_kwargs)
|
run_model(full_compilation_config, model_name, **model_kwargs)
|
||||||
|
|
||||||
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
||||||
assert num_compile_ranges in [1, 2]
|
assert num_compile_ranges in [1, 2, 3]
|
||||||
|
|
||||||
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
|
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
|
||||||
print("Fusion results:")
|
print("Fusion results:")
|
||||||
@@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
|||||||
|
|
||||||
# Now check the matches
|
# Now check the matches
|
||||||
for match_name in matches_check:
|
for match_name in matches_check:
|
||||||
num_ranges_activated = (
|
|
||||||
1 if match_name == "ar_rms_fusion" else num_compile_ranges
|
|
||||||
)
|
|
||||||
n_expected = tp_size * num_ranges_activated
|
|
||||||
|
|
||||||
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
|
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
|
||||||
|
|
||||||
|
# AR+RMS skips the largest range; SP skips the smallest.
|
||||||
|
# When both are enabled, AR+RMS activation count is
|
||||||
|
# model-dependent (hidden_size affects threshold), so derive
|
||||||
|
# from log data.
|
||||||
|
if (
|
||||||
|
match_name == "ar_rms_fusion"
|
||||||
|
and "sequence_parallel" in matches_check
|
||||||
|
and num_compile_ranges >= 2
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
|
||||||
|
), (
|
||||||
|
f"Expected multiple of {tp_size} ar_rms log entries, "
|
||||||
|
f"found {len(log_matches)}"
|
||||||
|
)
|
||||||
|
num_ranges_activated = len(log_matches) // tp_size
|
||||||
|
elif (
|
||||||
|
match_name in ("ar_rms_fusion", "sequence_parallel")
|
||||||
|
and num_compile_ranges >= 2
|
||||||
|
):
|
||||||
|
num_ranges_activated = num_compile_ranges - 1
|
||||||
|
else:
|
||||||
|
num_ranges_activated = num_compile_ranges
|
||||||
|
|
||||||
|
n_expected = tp_size * num_ranges_activated
|
||||||
assert len(log_matches) == n_expected, (
|
assert len(log_matches) == n_expected, (
|
||||||
f"Could not find {n_expected} {match_name} "
|
f"Could not find {n_expected} {match_name} "
|
||||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||||
@@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
|||||||
|
|
||||||
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
|
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
|
||||||
# AR+rms+quant takes precedence over rms+quant if activated.
|
# AR+rms+quant takes precedence over rms+quant if activated.
|
||||||
# That means we get full matching where ar+rms+quant was not activated,
|
# That means we get full matching where ar+rms+quant was not
|
||||||
# and less where it was
|
# activated, and less where it was (only the smallest range).
|
||||||
assert sum(m == expected_matches for m in log_matches) == tp_size * (
|
assert sum(m == expected_matches for m in log_matches) == tp_size * (
|
||||||
num_ranges_activated - 1
|
num_ranges_activated - 1
|
||||||
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
|
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
|
||||||
@@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
|||||||
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
|
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
|
||||||
f"where ar+rms+quant was activated"
|
f"where ar+rms+quant was activated"
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
match_name == "async_tp"
|
||||||
|
and "sequence_parallel" in matches_check
|
||||||
|
and num_compile_ranges >= 2
|
||||||
|
):
|
||||||
|
# AsyncTP only finds patterns on ranges where SP ran.
|
||||||
|
n_sp_ranges = num_compile_ranges - 1
|
||||||
|
assert (
|
||||||
|
sum(m == expected_matches for m in log_matches)
|
||||||
|
== tp_size * n_sp_ranges
|
||||||
|
), (
|
||||||
|
f"Expecting {expected_matches} async_tp on "
|
||||||
|
f"{tp_size * n_sp_ranges} SP-range entries, "
|
||||||
|
f"found: {log_matches}"
|
||||||
|
)
|
||||||
|
assert sum(m == 0 for m in log_matches) == tp_size, (
|
||||||
|
f"Expecting 0 async_tp on {tp_size} small-range entries "
|
||||||
|
f"(no SP), found: {log_matches}"
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
match_name == "ar_rms_fusion"
|
||||||
|
and "sequence_parallel" in matches_check
|
||||||
|
and num_compile_ranges >= 2
|
||||||
|
):
|
||||||
|
# SP consumes allreduce patterns first, so AR+RMS finds
|
||||||
|
# full matches only on the smallest range (no SP).
|
||||||
|
assert sum(m == expected_matches for m in log_matches) == tp_size, (
|
||||||
|
f"Expecting {expected_matches} ar_rms on "
|
||||||
|
f"{tp_size} small-range entries, found: {log_matches}"
|
||||||
|
)
|
||||||
|
assert sum(m == 0 for m in log_matches) == tp_size * (
|
||||||
|
num_ranges_activated - 1
|
||||||
|
), (
|
||||||
|
f"Expecting 0 ar_rms on "
|
||||||
|
f"{tp_size * (num_ranges_activated - 1)} large-range "
|
||||||
|
f"entries (SP took precedence), found: {log_matches}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
expected_matches_list = [expected_matches] * n_expected
|
expected_matches_list = [expected_matches] * n_expected
|
||||||
assert sorted(log_matches) == expected_matches_list, (
|
assert sorted(log_matches) == expected_matches_list, (
|
||||||
@@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
|||||||
f"found: {sorted(log_matches)}"
|
f"found: {sorted(log_matches)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if match_name == "ar_rms_fusion":
|
if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
|
||||||
log_matches = re.findall(
|
log_matches = re.findall(
|
||||||
r"pass_manager.py:\d+] Skipping "
|
r"pass_manager.py:\d+] Skipping "
|
||||||
r".*AllReduceFusionPass.* with compile range",
|
r".*AllReduceFusionPass.* with compile range",
|
||||||
@@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
|||||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
|
||||||
|
log_matches = re.findall(
|
||||||
|
r"pass_manager.py:\d+] Skipping "
|
||||||
|
r".*SequenceParallelismPass.* with compile range",
|
||||||
|
log_holder.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
|
||||||
|
assert len(log_matches) == n_expected, (
|
||||||
|
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
|
||||||
|
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||||
|
)
|
||||||
|
|
||||||
return run
|
return run
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ def test_tp2_async_tp_fp8_fusions(
|
|||||||
enable_qk_norm_rope_fusion=True,
|
enable_qk_norm_rope_fusion=True,
|
||||||
enable_sp=True,
|
enable_sp=True,
|
||||||
fuse_gemm_comms=True,
|
fuse_gemm_comms=True,
|
||||||
|
fuse_allreduce_rms=False,
|
||||||
|
# Override threshold for testing (models have small hidden_size)
|
||||||
|
sp_min_token_num=512,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,6 +126,9 @@ def test_tp2_async_tp_fusions(
|
|||||||
enable_qk_norm_rope_fusion=True,
|
enable_qk_norm_rope_fusion=True,
|
||||||
enable_sp=True,
|
enable_sp=True,
|
||||||
fuse_gemm_comms=True,
|
fuse_gemm_comms=True,
|
||||||
|
fuse_allreduce_rms=False,
|
||||||
|
# Override threshold for testing (models have small hidden_size)
|
||||||
|
sp_min_token_num=512,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,3 +147,130 @@ def test_tp2_async_tp_fusions(
|
|||||||
matches_check,
|
matches_check,
|
||||||
tp_size=2,
|
tp_size=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, matches_fn, model_kwargs, hf_overrides",
    [llama3_8b_fp8, llama4_scout_fp8],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fp8_fusions(
    model_name: str,
    matches_fn: Callable[[int], Matches],
    model_kwargs: dict,
    hf_overrides: Callable[[int], dict],
    attn_backend: AttentionBackendCase,
    n_layers: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    run_e2e_fusion_test,
    monkeypatch,
):
    """End-to-end TP=2 FP8 fusion check with SP and allreduce+RMS both enabled.

    Turns on the quant fusions, QK-norm/rope fusion, sequence parallelism,
    GEMM+comms fusion and allreduce+RMS fusion together, then delegates the
    run and the match-count verification to ``run_e2e_fusion_test``.
    """
    matches = matches_fn(n_layers)

    if is_blackwell():
        # Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
        monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")

    # Reduce size of model and skip weight loading time
    model_kwargs.update(
        hf_overrides=hf_overrides(n_layers),
        load_format="dummy",
        max_model_len=1024,
    )

    fusion_passes = PassConfig(
        fuse_norm_quant=True,
        fuse_act_quant=True,
        fuse_attn_quant=True,
        enable_qk_norm_rope_fusion=True,
        enable_sp=True,
        fuse_gemm_comms=True,
        fuse_allreduce_rms=True,
        # Override threshold for testing (models have small hidden_size)
        sp_min_token_num=512,
    )
    compilation_config = {
        "use_inductor_graph_partition": inductor_graph_partition,
        "custom_ops": custom_ops.split(","),
        "pass_config": fusion_passes,
    }

    matches_check = [
        "rms_quant_fusion",
        "act_quant_fusion",
        "norm_rope_fusion",
        "attn_quant_fusion",
        "ar_rms_fusion",
        "sequence_parallel",
        "async_tp",
    ]

    run_e2e_fusion_test(
        model_name,
        matches,
        model_kwargs,
        attn_backend,
        compilation_config,
        matches_check,
        tp_size=2,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, matches_fn, model_kwargs, hf_overrides",
    [llama3_8b, qwen3_a3b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fusions(
    model_name: str,
    matches_fn: Callable[[int], Matches],
    model_kwargs: dict,
    hf_overrides: Callable[[int], dict],
    attn_backend: AttentionBackendCase,
    n_layers: int,
    custom_ops: str,
    inductor_graph_partition: bool,
    run_e2e_fusion_test,
):
    """End-to-end TP=2 fusion check with SP and allreduce+RMS both enabled.

    Enables QK-norm/rope fusion, sequence parallelism, GEMM+comms fusion and
    allreduce+RMS fusion (no quant fusions), then hands the run and the
    match-count verification to ``run_e2e_fusion_test``.
    """
    matches = matches_fn(n_layers)

    # Reduce size of model and skip weight loading time
    model_kwargs.update(
        hf_overrides=hf_overrides(n_layers),
        load_format="dummy",
        max_model_len=1024,
    )

    fusion_passes = PassConfig(
        enable_qk_norm_rope_fusion=True,
        enable_sp=True,
        fuse_gemm_comms=True,
        fuse_allreduce_rms=True,
        # Override threshold for testing (models have small hidden_size)
        sp_min_token_num=512,
    )
    compilation_config = {
        "use_inductor_graph_partition": inductor_graph_partition,
        "custom_ops": custom_ops.split(","),
        "pass_config": fusion_passes,
    }

    matches_check = [
        "norm_rope_fusion",
        "ar_rms_fusion",
        "sequence_parallel",
        "async_tp",
    ]

    run_e2e_fusion_test(
        model_name,
        matches,
        model_kwargs,
        attn_backend,
        compilation_config,
        matches_check,
        tp_size=2,
    )
|
||||||
|
|||||||
@@ -421,6 +421,7 @@ def test_cudagraph_sizes_post_init(
|
|||||||
fuse_norm_quant=True,
|
fuse_norm_quant=True,
|
||||||
fuse_act_quant=True,
|
fuse_act_quant=True,
|
||||||
eliminate_noops=True,
|
eliminate_noops=True,
|
||||||
|
sp_min_token_num=512 if enable_sp else None,
|
||||||
),
|
),
|
||||||
cudagraph_mode=cudagraph_mode,
|
cudagraph_mode=cudagraph_mode,
|
||||||
)
|
)
|
||||||
|
|||||||
110
tests/compile/test_sequence_parallelism_threshold.py
Normal file
110
tests/compile/test_sequence_parallelism_threshold.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||||
|
SP_MIN_HIDDEN_SIZE,
|
||||||
|
SP_MIN_PER_GPU_SIZE_MB,
|
||||||
|
get_sequence_parallelism_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSequenceParallelismThreshold:
    """Unit tests covering get_sequence_parallelism_threshold."""

    def test_non_cuda_returns_none(self, mock_cuda_platform):
        """A platform whose is_cuda() is False yields no threshold."""
        with mock_cuda_platform(is_cuda=False):
            threshold = get_sequence_parallelism_threshold(
                hidden_size=8192, tp_size=2, element_size=2
            )
            assert threshold is None

    def test_unsupported_device_capability_returns_none(self, mock_cuda_platform):
        """A capability with no configured entry (e.g., sm80) yields None."""
        with mock_cuda_platform(capability=(8, 0)):
            threshold = get_sequence_parallelism_threshold(
                hidden_size=8192, tp_size=2, element_size=2
            )
            assert threshold is None

    def test_small_hidden_size_returns_none(self, mock_cuda_platform):
        """On H100, a hidden_size under the configured minimum yields None."""
        with mock_cuda_platform(capability=(9, 0)):
            threshold = get_sequence_parallelism_threshold(
                hidden_size=4096,  # 4096 < 8192
                tp_size=2,
                element_size=2,
            )
            assert threshold is None

    def test_h100_large_model_returns_threshold(self, mock_cuda_platform):
        """On H100, a large enough hidden_size yields the computed threshold."""
        with mock_cuda_platform(capability=(9, 0)):
            hidden_size = 8192
            tp_size = 2
            element_size = 2  # float16/bfloat16

            threshold = get_sequence_parallelism_threshold(
                hidden_size=hidden_size,
                tp_size=tp_size,
                element_size=element_size,
            )

            # Verify calculation: (8 * 2 * 1024 * 1024) // (8192 * 2) = 1024
            MiB = 1024 * 1024
            total_min_bytes = SP_MIN_PER_GPU_SIZE_MB[90] * tp_size * MiB
            bytes_per_token = hidden_size * element_size
            expected = int(total_min_bytes // bytes_per_token)
            assert threshold == expected
            assert threshold == 1024

    @pytest.mark.parametrize(
        "hidden_size,tp_size,element_size,expected",
        [
            # Boundary: exactly at min hidden size threshold, tp_size=1
            # (8 * 1 * 1024 * 1024) // (8192 * 2) = 512
            (8192, 1, 2, 512),
            # Larger hidden size reduces token threshold
            # (8 * 1 * 1024 * 1024) // (16384 * 2) = 256
            (16384, 1, 2, 256),
            # Larger tp_size increases token threshold
            # (8 * 4 * 1024 * 1024) // (8192 * 2) = 2048
            (8192, 4, 2, 2048),
            # Larger element_size (fp32) reduces token threshold
            # (8 * 2 * 1024 * 1024) // (8192 * 4) = 512
            (8192, 2, 4, 512),
        ],
    )
    def test_threshold_calculation_variations(
        self, mock_cuda_platform, hidden_size, tp_size, element_size, expected
    ):
        """Check the computed threshold across several parameter combinations."""
        with mock_cuda_platform(capability=(9, 0)):
            threshold = get_sequence_parallelism_threshold(
                hidden_size=hidden_size,
                tp_size=tp_size,
                element_size=element_size,
            )
            assert threshold == expected

    def test_hidden_size_boundary(self, mock_cuda_platform):
        """Check both sides of the exact hidden_size cutoff for capability 90."""
        with mock_cuda_platform(capability=(9, 0)):
            # Just below threshold
            below = get_sequence_parallelism_threshold(
                hidden_size=SP_MIN_HIDDEN_SIZE[90] - 1,
                tp_size=2,
                element_size=2,
            )
            assert below is None

            # Exactly at threshold
            at_cutoff = get_sequence_parallelism_threshold(
                hidden_size=SP_MIN_HIDDEN_SIZE[90],
                tp_size=2,
                element_size=2,
            )
            assert at_cutoff is not None
|
||||||
@@ -27,6 +27,63 @@ from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNo
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Min hidden size per device capability for sequence parallelism
|
||||||
|
# Only apply sequence parallelism for models with hidden_size >= threshold
|
||||||
|
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
|
||||||
|
90: 8192, # H100: only for models with hidden_size >= 8192
|
||||||
|
}
|
||||||
|
|
||||||
|
# Min size per GPU per device capability for sequence parallelism
|
||||||
|
# Total min size = min_per_gpu_size * tp_size
|
||||||
|
# This ensures the threshold scales appropriately with tensor parallelism
|
||||||
|
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
|
||||||
|
90: 8, # 8MB per GPU for H100
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_sequence_parallelism_threshold(
    hidden_size: int,
    tp_size: int,
    element_size: int,
) -> int | None:
    """
    Compute the minimum token count at which sequence parallelism applies.

    The result is derived from two per-device-capability tables,
    SP_MIN_HIDDEN_SIZE and SP_MIN_PER_GPU_SIZE_MB, both keyed by the integer
    form of the current platform's device capability.

    Args:
        hidden_size: model hidden dimension.
        tp_size: tensor-parallel world size; scales the total minimum size.
        element_size: size in bytes of one activation element.

    Returns:
        None when any of the following holds:
        - the current platform is not CUDA,
        - the platform reports no device capability,
        - either table has no entry for this capability,
        - hidden_size is below SP_MIN_HIDDEN_SIZE[capability].
        Otherwise:
            (min_per_gpu_size_mb * MiB * tp_size) // (hidden_size * element_size)
        converted to int.
    """
    # Imported lazily, at call time, rather than at module import.
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        return None

    capability = current_platform.get_device_capability()
    if capability is None:
        return None

    capability_key = capability.to_int()
    hidden_size_floor = SP_MIN_HIDDEN_SIZE.get(capability_key)
    per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(capability_key)

    # No configured thresholds for this device, or the model is too small.
    if (
        hidden_size_floor is None
        or per_gpu_size_mb is None
        or hidden_size < hidden_size_floor
    ):
        return None

    bytes_per_token = hidden_size * element_size
    total_min_bytes = per_gpu_size_mb * (1024 * 1024) * tp_size
    return int(total_min_bytes // bytes_per_token)
|
||||||
|
|
||||||
|
|
||||||
def get_first_out_wrapper(
|
def get_first_out_wrapper(
|
||||||
fn: Callable[..., Sequence[torch.Tensor]],
|
fn: Callable[..., Sequence[torch.Tensor]],
|
||||||
@@ -309,6 +366,23 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
|||||||
def __init__(self, config: VllmConfig) -> None:
|
def __init__(self, config: VllmConfig) -> None:
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Get min_token_num threshold
|
||||||
|
# Read min_token_num from config (calculated during config init)
|
||||||
|
self.min_token_num = None
|
||||||
|
if config.model_config is not None:
|
||||||
|
pass_config = config.compilation_config.pass_config
|
||||||
|
self.min_token_num = pass_config.sp_min_token_num
|
||||||
|
|
||||||
|
if self.min_token_num is not None:
|
||||||
|
# Take the min to avoid exceeding max_num_batched_tokens
|
||||||
|
max_batched = config.scheduler_config.max_num_batched_tokens
|
||||||
|
if max_batched is not None:
|
||||||
|
self.min_token_num = min(self.min_token_num, max_batched)
|
||||||
|
logger.debug_once(
|
||||||
|
f"Sequence parallelism min token threshold: {self.min_token_num}",
|
||||||
|
scope="global",
|
||||||
|
)
|
||||||
|
|
||||||
# Used to clean up redundant views created temporarily
|
# Used to clean up redundant views created temporarily
|
||||||
# to circumvent residual shape change issues
|
# to circumvent residual shape change issues
|
||||||
self.noop_cleanup = NoOpEliminationPass(config)
|
self.noop_cleanup = NoOpEliminationPass(config)
|
||||||
@@ -339,29 +413,36 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
|||||||
self.dump_patterns(config, self.patterns)
|
self.dump_patterns(config, self.patterns)
|
||||||
|
|
||||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||||
# When sequence parallelism is enabled, the residual tensor from RMSNorm
|
"""
|
||||||
# needs to be split along the sequence dimension. However, this dimension
|
Determines if sequence parallelism should be applied for the given
|
||||||
# is symbolic during piecewise compilation, and splitting symbolic shapes
|
compile range.
|
||||||
# is not supported.
|
|
||||||
#
|
SP is only beneficial for larger batch sizes where the communication
|
||||||
# This pass is therefore only applied when the sequence dimension is
|
overhead is amortized. For small batches, the overhead of splitting
|
||||||
# concrete:
|
and gathering tensors across TP ranks outweighs the benefits.
|
||||||
# 1. In full-graph compilation mode (no Dynamo splitting ops are used).
|
|
||||||
# For this case we always pad num_tokens to be a multiple of
|
Returns False (SP disabled) when:
|
||||||
# tensor_parallel_size, so there's no need to check shape % tp_size == 0.
|
- Using piecewise compilation with non-concrete or TP-indivisible sizes
|
||||||
# 2. For specific shape provided during compilation (e.g., from
|
- min_token_num is None (SP disabled for this device/config)
|
||||||
# `compile_sizes`), which must be divisible by the tensor-parallel
|
- The compile range starts below the minimum token threshold
|
||||||
# size.
|
"""
|
||||||
|
# For piecewise compilation (not using inductor graph partition),
|
||||||
|
# we need concrete sizes that are divisible by TP for correct splitting
|
||||||
if (
|
if (
|
||||||
not self.compilation_config.splitting_ops
|
not self.compilation_config.use_inductor_graph_partition
|
||||||
or self.compilation_config.use_inductor_graph_partition
|
and self.compilation_config.splitting_ops
|
||||||
):
|
):
|
||||||
return True
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
|
||||||
result: bool = (compile_range.is_single_size()) and (
|
return False
|
||||||
compile_range.end % tp_size == 0
|
|
||||||
)
|
# min_token_num is None when SP is disabled for this device/config
|
||||||
return result
|
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
|
||||||
|
if self.min_token_num is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Only apply SP when batch size meets the minimum threshold
|
||||||
|
return compile_range.start >= self.min_token_num
|
||||||
|
|
||||||
@VllmInductorPass.time_and_log
|
@VllmInductorPass.time_and_log
|
||||||
def __call__(self, graph: fx.Graph) -> None:
|
def __call__(self, graph: fx.Graph) -> None:
|
||||||
|
|||||||
@@ -118,7 +118,9 @@ class PassConfig:
|
|||||||
eliminate_noops: bool = Field(default=True)
|
eliminate_noops: bool = Field(default=True)
|
||||||
"""Eliminate no-op ops."""
|
"""Eliminate no-op ops."""
|
||||||
enable_sp: bool = Field(default=None)
|
enable_sp: bool = Field(default=None)
|
||||||
"""Enable sequence parallelism."""
|
"""Enable sequence parallelism. Requires TP>1. Automatically disabled
|
||||||
|
if the model's hidden_size is too small for SP to be beneficial
|
||||||
|
(threshold is device-capability dependent)."""
|
||||||
fuse_gemm_comms: bool = Field(default=None)
|
fuse_gemm_comms: bool = Field(default=None)
|
||||||
"""Enable async TP."""
|
"""Enable async TP."""
|
||||||
fuse_allreduce_rms: bool = Field(default=None)
|
fuse_allreduce_rms: bool = Field(default=None)
|
||||||
@@ -155,6 +157,11 @@ class PassConfig:
|
|||||||
8: 1, # 1MB
|
8: 1, # 1MB
|
||||||
},
|
},
|
||||||
}, where key is the device capability"""
|
}, where key is the device capability"""
|
||||||
|
sp_min_token_num: int | None = None
|
||||||
|
"""The minimum number of tokens at or above which vLLM should use
|
||||||
|
sequence parallelism. Specified as an integer token count.
|
||||||
|
If unspecified, falls back to default values which are compute
|
||||||
|
capability and world size dependent."""
|
||||||
|
|
||||||
# TODO(luka) better pass enabling system.
|
# TODO(luka) better pass enabling system.
|
||||||
|
|
||||||
|
|||||||
@@ -853,8 +853,33 @@ class VllmConfig:
|
|||||||
logger.warning("Sequence Parallelism requires TP>1, disabling")
|
logger.warning("Sequence Parallelism requires TP>1, disabling")
|
||||||
self.compilation_config.pass_config.enable_sp = False
|
self.compilation_config.pass_config.enable_sp = False
|
||||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||||
|
else:
|
||||||
|
# Compute SP threshold early; disable if None (model too
|
||||||
|
# small) before +rms_norm gets forced into custom_ops.
|
||||||
|
pass_config = self.compilation_config.pass_config
|
||||||
|
if pass_config.sp_min_token_num is None:
|
||||||
|
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||||
|
get_sequence_parallelism_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
elif "-rms_norm" in self.compilation_config.custom_ops:
|
tp_size = self.parallel_config.tensor_parallel_size
|
||||||
|
hidden_size = self.model_config.get_hidden_size()
|
||||||
|
element_size = self.model_config.dtype.itemsize
|
||||||
|
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
|
||||||
|
hidden_size, tp_size, element_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if pass_config.sp_min_token_num is None:
|
||||||
|
logger.warning(
|
||||||
|
"Model hidden_size too small for the SP "
|
||||||
|
"threshold heuristic, disabling. To force SP, "
|
||||||
|
"set pass_config.sp_min_token_num manually."
|
||||||
|
)
|
||||||
|
self.compilation_config.pass_config.enable_sp = False
|
||||||
|
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||||
|
|
||||||
|
if self.compilation_config.pass_config.enable_sp:
|
||||||
|
if "-rms_norm" in self.compilation_config.custom_ops:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"RMS norm force disabled, sequence parallelism might break"
|
"RMS norm force disabled, sequence parallelism might break"
|
||||||
)
|
)
|
||||||
@@ -1456,6 +1481,36 @@ class VllmConfig:
|
|||||||
"allreduce-rms fusion will be enabled for all num_tokens."
|
"allreduce-rms fusion will be enabled for all num_tokens."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add the compile ranges for sequence parallelism
|
||||||
|
if compilation_config.pass_config.enable_sp:
|
||||||
|
pass_config = compilation_config.pass_config
|
||||||
|
|
||||||
|
# Calculate min_token_num if not explicitly provided
|
||||||
|
# User override works regardless of hidden_size
|
||||||
|
if pass_config.sp_min_token_num is None:
|
||||||
|
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||||
|
get_sequence_parallelism_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
tp_size = self.parallel_config.tensor_parallel_size
|
||||||
|
hidden_size = self.model_config.get_hidden_size()
|
||||||
|
element_size = self.model_config.dtype.itemsize
|
||||||
|
pass_config.sp_min_token_num = get_sequence_parallelism_threshold(
|
||||||
|
hidden_size, tp_size, element_size
|
||||||
|
)
|
||||||
|
|
||||||
|
min_token_num = pass_config.sp_min_token_num
|
||||||
|
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
|
if min_token_num is not None and (
|
||||||
|
max_num_batched_tokens is not None
|
||||||
|
and min_token_num < max_num_batched_tokens
|
||||||
|
and min_token_num > 1
|
||||||
|
):
|
||||||
|
# Add split point at min_token_num - 1 to ensure SP applies
|
||||||
|
# starting from min_token_num
|
||||||
|
# This creates ranges: [1, min-1] (no SP), [min, max] (SP applies)
|
||||||
|
computed_compile_ranges_split_points.append(min_token_num - 1)
|
||||||
|
|
||||||
if compilation_config.pass_config.fuse_rope_kvcache:
|
if compilation_config.pass_config.fuse_rope_kvcache:
|
||||||
max_token_num = (
|
max_token_num = (
|
||||||
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
|
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
|
||||||
|
|||||||
Reference in New Issue
Block a user