[torch.compile] Sequence Parallelism threshold compile ranges (#28672)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com> Signed-off-by: Jason Li <jasonlizhengjian@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
34
tests/compile/conftest.py
Normal file
34
tests/compile/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cuda_platform():
|
||||
"""
|
||||
Fixture that returns a factory for creating mocked CUDA platforms.
|
||||
|
||||
Usage:
|
||||
def test_something(mock_cuda_platform):
|
||||
with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
|
||||
# test code
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
|
||||
mock_platform = MagicMock()
|
||||
mock_platform.is_cuda.return_value = is_cuda
|
||||
if capability is not None:
|
||||
mock_platform.get_device_capability.return_value = DeviceCapability(
|
||||
*capability
|
||||
)
|
||||
with patch("vllm.platforms.current_platform", mock_platform):
|
||||
yield mock_platform
|
||||
|
||||
return _mock_platform
|
||||
@@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
|
||||
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
||||
assert num_compile_ranges in [1, 2]
|
||||
assert num_compile_ranges in [1, 2, 3]
|
||||
|
||||
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
|
||||
print("Fusion results:")
|
||||
@@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
# Now check the matches
|
||||
for match_name in matches_check:
|
||||
num_ranges_activated = (
|
||||
1 if match_name == "ar_rms_fusion" else num_compile_ranges
|
||||
)
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
|
||||
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
|
||||
|
||||
# AR+RMS skips the largest range; SP skips the smallest.
|
||||
# When both are enabled, AR+RMS activation count is
|
||||
# model-dependent (hidden_size affects threshold), so derive
|
||||
# from log data.
|
||||
if (
|
||||
match_name == "ar_rms_fusion"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
assert (
|
||||
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
|
||||
), (
|
||||
f"Expected multiple of {tp_size} ar_rms log entries, "
|
||||
f"found {len(log_matches)}"
|
||||
)
|
||||
num_ranges_activated = len(log_matches) // tp_size
|
||||
elif (
|
||||
match_name in ("ar_rms_fusion", "sequence_parallel")
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
num_ranges_activated = num_compile_ranges - 1
|
||||
else:
|
||||
num_ranges_activated = num_compile_ranges
|
||||
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
@@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
|
||||
# AR+rms+quant takes precedence over rms+quant if activated.
|
||||
# That means we get full matching where ar+rms+quant was not activated,
|
||||
# and less where it was
|
||||
# That means we get full matching where ar+rms+quant was not
|
||||
# activated, and less where it was (only the smallest range).
|
||||
assert sum(m == expected_matches for m in log_matches) == tp_size * (
|
||||
num_ranges_activated - 1
|
||||
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
|
||||
@@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
|
||||
f"where ar+rms+quant was activated"
|
||||
)
|
||||
elif (
|
||||
match_name == "async_tp"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
# AsyncTP only finds patterns on ranges where SP ran.
|
||||
n_sp_ranges = num_compile_ranges - 1
|
||||
assert (
|
||||
sum(m == expected_matches for m in log_matches)
|
||||
== tp_size * n_sp_ranges
|
||||
), (
|
||||
f"Expecting {expected_matches} async_tp on "
|
||||
f"{tp_size * n_sp_ranges} SP-range entries, "
|
||||
f"found: {log_matches}"
|
||||
)
|
||||
assert sum(m == 0 for m in log_matches) == tp_size, (
|
||||
f"Expecting 0 async_tp on {tp_size} small-range entries "
|
||||
f"(no SP), found: {log_matches}"
|
||||
)
|
||||
elif (
|
||||
match_name == "ar_rms_fusion"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
# SP consumes allreduce patterns first, so AR+RMS finds
|
||||
# full matches only on the smallest range (no SP).
|
||||
assert sum(m == expected_matches for m in log_matches) == tp_size, (
|
||||
f"Expecting {expected_matches} ar_rms on "
|
||||
f"{tp_size} small-range entries, found: {log_matches}"
|
||||
)
|
||||
assert sum(m == 0 for m in log_matches) == tp_size * (
|
||||
num_ranges_activated - 1
|
||||
), (
|
||||
f"Expecting 0 ar_rms on "
|
||||
f"{tp_size * (num_ranges_activated - 1)} large-range "
|
||||
f"entries (SP took precedence), found: {log_matches}"
|
||||
)
|
||||
else:
|
||||
expected_matches_list = [expected_matches] * n_expected
|
||||
assert sorted(log_matches) == expected_matches_list, (
|
||||
@@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"found: {sorted(log_matches)}"
|
||||
)
|
||||
|
||||
if match_name == "ar_rms_fusion":
|
||||
if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
|
||||
log_matches = re.findall(
|
||||
r"pass_manager.py:\d+] Skipping "
|
||||
r".*AllReduceFusionPass.* with compile range",
|
||||
@@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
|
||||
log_matches = re.findall(
|
||||
r"pass_manager.py:\d+] Skipping "
|
||||
r".*SequenceParallelismPass.* with compile range",
|
||||
log_holder.text,
|
||||
)
|
||||
|
||||
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
|
||||
assert len(log_matches) == n_expected, (
|
||||
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
@@ -66,6 +66,9 @@ def test_tp2_async_tp_fp8_fusions(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=False,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -123,6 +126,9 @@ def test_tp2_async_tp_fusions(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=False,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -141,3 +147,130 @@ def test_tp2_async_tp_fusions(
|
||||
matches_check,
|
||||
tp_size=2,
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides",
|
||||
[llama3_8b_fp8, llama4_scout_fp8],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [4])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
def test_tp2_sp_ar_rms_fp8_fusions(
|
||||
model_name: str,
|
||||
matches_fn: Callable[[int], Matches],
|
||||
model_kwargs: dict,
|
||||
hf_overrides: Callable[[int], dict],
|
||||
attn_backend: AttentionBackendCase,
|
||||
n_layers: int,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
run_e2e_fusion_test,
|
||||
monkeypatch,
|
||||
):
|
||||
matches = matches_fn(n_layers)
|
||||
|
||||
if is_blackwell():
|
||||
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
|
||||
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
|
||||
|
||||
# Reduce size of model and skip weight loading time
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
custom_ops=custom_ops.split(","),
|
||||
pass_config=PassConfig(
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
fuse_attn_quant=True,
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=True,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
matches_check = [
|
||||
"rms_quant_fusion",
|
||||
"act_quant_fusion",
|
||||
"norm_rope_fusion",
|
||||
"attn_quant_fusion",
|
||||
"ar_rms_fusion",
|
||||
"sequence_parallel",
|
||||
"async_tp",
|
||||
]
|
||||
|
||||
run_e2e_fusion_test(
|
||||
model_name,
|
||||
matches,
|
||||
model_kwargs,
|
||||
attn_backend,
|
||||
compilation_config,
|
||||
matches_check,
|
||||
tp_size=2,
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides",
|
||||
[llama3_8b, qwen3_a3b],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [4])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
def test_tp2_sp_ar_rms_fusions(
|
||||
model_name: str,
|
||||
matches_fn: Callable[[int], Matches],
|
||||
model_kwargs: dict,
|
||||
hf_overrides: Callable[[int], dict],
|
||||
attn_backend: AttentionBackendCase,
|
||||
n_layers: int,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
run_e2e_fusion_test,
|
||||
):
|
||||
matches = matches_fn(n_layers)
|
||||
|
||||
# Reduce size of model and skip weight loading time
|
||||
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
|
||||
model_kwargs["load_format"] = "dummy"
|
||||
model_kwargs["max_model_len"] = 1024
|
||||
|
||||
compilation_config = dict(
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
custom_ops=custom_ops.split(","),
|
||||
pass_config=PassConfig(
|
||||
enable_qk_norm_rope_fusion=True,
|
||||
enable_sp=True,
|
||||
fuse_gemm_comms=True,
|
||||
fuse_allreduce_rms=True,
|
||||
# Override threshold for testing (models have small hidden_size)
|
||||
sp_min_token_num=512,
|
||||
),
|
||||
)
|
||||
|
||||
matches_check = [
|
||||
"norm_rope_fusion",
|
||||
"ar_rms_fusion",
|
||||
"sequence_parallel",
|
||||
"async_tp",
|
||||
]
|
||||
|
||||
run_e2e_fusion_test(
|
||||
model_name,
|
||||
matches,
|
||||
model_kwargs,
|
||||
attn_backend,
|
||||
compilation_config,
|
||||
matches_check,
|
||||
tp_size=2,
|
||||
)
|
||||
|
||||
@@ -421,6 +421,7 @@ def test_cudagraph_sizes_post_init(
|
||||
fuse_norm_quant=True,
|
||||
fuse_act_quant=True,
|
||||
eliminate_noops=True,
|
||||
sp_min_token_num=512 if enable_sp else None,
|
||||
),
|
||||
cudagraph_mode=cudagraph_mode,
|
||||
)
|
||||
|
||||
110
tests/compile/test_sequence_parallelism_threshold.py
Normal file
110
tests/compile/test_sequence_parallelism_threshold.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||
SP_MIN_HIDDEN_SIZE,
|
||||
SP_MIN_PER_GPU_SIZE_MB,
|
||||
get_sequence_parallelism_threshold,
|
||||
)
|
||||
|
||||
|
||||
class TestGetSequenceParallelismThreshold:
|
||||
"""Tests for get_sequence_parallelism_threshold function."""
|
||||
|
||||
def test_non_cuda_returns_none(self, mock_cuda_platform):
|
||||
"""Non-CUDA platforms should return None."""
|
||||
with mock_cuda_platform(is_cuda=False):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=8192, tp_size=2, element_size=2
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_unsupported_device_capability_returns_none(self, mock_cuda_platform):
|
||||
"""Unsupported device capabilities (e.g., sm80) should return None."""
|
||||
with mock_cuda_platform(capability=(8, 0)):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=8192, tp_size=2, element_size=2
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_small_hidden_size_returns_none(self, mock_cuda_platform):
|
||||
"""H100 with hidden_size below threshold should return None."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=4096,
|
||||
tp_size=2,
|
||||
element_size=2, # 4096 < 8192
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_h100_large_model_returns_threshold(self, mock_cuda_platform):
|
||||
"""H100 with large enough hidden_size should return calculated threshold."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
hidden_size = 8192
|
||||
tp_size = 2
|
||||
element_size = 2 # float16/bfloat16
|
||||
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=hidden_size,
|
||||
tp_size=tp_size,
|
||||
element_size=element_size,
|
||||
)
|
||||
|
||||
# Verify calculation: (8 * 2 * 1024 * 1024) // (8192 * 2) = 1024
|
||||
MiB = 1024 * 1024
|
||||
expected = int(
|
||||
(SP_MIN_PER_GPU_SIZE_MB[90] * tp_size * MiB)
|
||||
// (hidden_size * element_size)
|
||||
)
|
||||
assert result == expected
|
||||
assert result == 1024
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hidden_size,tp_size,element_size,expected",
|
||||
[
|
||||
# Boundary: exactly at min hidden size threshold, tp_size=1
|
||||
# (8 * 1 * 1024 * 1024) // (8192 * 2) = 512
|
||||
(8192, 1, 2, 512),
|
||||
# Larger hidden size reduces token threshold
|
||||
# (8 * 1 * 1024 * 1024) // (16384 * 2) = 256
|
||||
(16384, 1, 2, 256),
|
||||
# Larger tp_size increases token threshold
|
||||
# (8 * 4 * 1024 * 1024) // (8192 * 2) = 2048
|
||||
(8192, 4, 2, 2048),
|
||||
# Larger element_size (fp32) reduces token threshold
|
||||
# (8 * 2 * 1024 * 1024) // (8192 * 4) = 512
|
||||
(8192, 2, 4, 512),
|
||||
],
|
||||
)
|
||||
def test_threshold_calculation_variations(
|
||||
self, mock_cuda_platform, hidden_size, tp_size, element_size, expected
|
||||
):
|
||||
"""Test threshold calculation with various parameter combinations."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=hidden_size,
|
||||
tp_size=tp_size,
|
||||
element_size=element_size,
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_hidden_size_boundary(self, mock_cuda_platform):
|
||||
"""Test behavior at the exact hidden_size boundary."""
|
||||
with mock_cuda_platform(capability=(9, 0)):
|
||||
# Just below threshold
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=SP_MIN_HIDDEN_SIZE[90] - 1,
|
||||
tp_size=2,
|
||||
element_size=2,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Exactly at threshold
|
||||
result = get_sequence_parallelism_threshold(
|
||||
hidden_size=SP_MIN_HIDDEN_SIZE[90],
|
||||
tp_size=2,
|
||||
element_size=2,
|
||||
)
|
||||
assert result is not None
|
||||
Reference in New Issue
Block a user