[torch.compile] Sequence Parallelism threshold compile ranges (#28672)

Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Jason Li
2026-02-25 21:00:12 -08:00
committed by GitHub
parent 4171ff6dd9
commit 9d37941017
8 changed files with 524 additions and 32 deletions

View File

@@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2]
assert num_compile_ranges in [1, 2, 3]
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
print("Fusion results:")
@@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
# Now check the matches
for match_name in matches_check:
num_ranges_activated = (
1 if match_name == "ar_rms_fusion" else num_compile_ranges
)
n_expected = tp_size * num_ranges_activated
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
# AR+RMS skips the largest range; SP skips the smallest.
# When both are enabled, AR+RMS activation count is
# model-dependent (hidden_size affects threshold), so derive
# from log data.
if (
match_name == "ar_rms_fusion"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
assert (
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
), (
f"Expected multiple of {tp_size} ar_rms log entries, "
f"found {len(log_matches)}"
)
num_ranges_activated = len(log_matches) // tp_size
elif (
match_name in ("ar_rms_fusion", "sequence_parallel")
and num_compile_ranges >= 2
):
num_ranges_activated = num_compile_ranges - 1
else:
num_ranges_activated = num_compile_ranges
n_expected = tp_size * num_ranges_activated
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
@@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
# AR+rms+quant takes precedence over rms+quant if activated.
# That means we get full matching where ar+rms+quant was not activated,
# and less where it was
# That means we get full matching where ar+rms+quant was not
# activated, and less where it was (only the smallest range).
assert sum(m == expected_matches for m in log_matches) == tp_size * (
num_ranges_activated - 1
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
@@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
f"where ar+rms+quant was activated"
)
elif (
match_name == "async_tp"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
# AsyncTP only finds patterns on ranges where SP ran.
n_sp_ranges = num_compile_ranges - 1
assert (
sum(m == expected_matches for m in log_matches)
== tp_size * n_sp_ranges
), (
f"Expecting {expected_matches} async_tp on "
f"{tp_size * n_sp_ranges} SP-range entries, "
f"found: {log_matches}"
)
assert sum(m == 0 for m in log_matches) == tp_size, (
f"Expecting 0 async_tp on {tp_size} small-range entries "
f"(no SP), found: {log_matches}"
)
elif (
match_name == "ar_rms_fusion"
and "sequence_parallel" in matches_check
and num_compile_ranges >= 2
):
# SP consumes allreduce patterns first, so AR+RMS finds
# full matches only on the smallest range (no SP).
assert sum(m == expected_matches for m in log_matches) == tp_size, (
f"Expecting {expected_matches} ar_rms on "
f"{tp_size} small-range entries, found: {log_matches}"
)
assert sum(m == 0 for m in log_matches) == tp_size * (
num_ranges_activated - 1
), (
f"Expecting 0 ar_rms on "
f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}"
)
else:
expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, (
@@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"found: {sorted(log_matches)}"
)
if match_name == "ar_rms_fusion":
if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
log_matches = re.findall(
r"pass_manager.py:\d+] Skipping "
r".*AllReduceFusionPass.* with compile range",
@@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
log_matches = re.findall(
r"pass_manager.py:\d+] Skipping "
r".*SequenceParallelismPass.* with compile range",
log_holder.text,
)
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
assert len(log_matches) == n_expected, (
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
return run

View File

@@ -66,6 +66,9 @@ def test_tp2_async_tp_fp8_fusions(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
@@ -123,6 +126,9 @@ def test_tp2_async_tp_fusions(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
@@ -141,3 +147,130 @@ def test_tp2_async_tp_fusions(
matches_check,
tp_size=2,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp8, llama4_scout_fp8],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fp8_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
monkeypatch,
):
matches = matches_fn(n_layers)
if is_blackwell():
# Disable FlashInfer scaled_mm FP8 as it's not supported in async tp patterns
monkeypatch.setenv("VLLM_DISABLED_KERNELS", "FlashInferFP8ScaledMMLinearKernel")
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
fuse_norm_quant=True,
fuse_act_quant=True,
fuse_attn_quant=True,
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"rms_quant_fusion",
"act_quant_fusion",
"norm_rope_fusion",
"attn_quant_fusion",
"ar_rms_fusion",
"sequence_parallel",
"async_tp",
]
run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b, qwen3_a3b],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
def test_tp2_sp_ar_rms_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
):
matches = matches_fn(n_layers)
# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
enable_qk_norm_rope_fusion=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)
matches_check = [
"norm_rope_fusion",
"ar_rms_fusion",
"sequence_parallel",
"async_tp",
]
run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)