[torch.compile] Sequence Parallelism threshold compile ranges (#28672)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com> Signed-off-by: Jason Li <jasonlizhengjian@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -94,7 +94,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
|
||||
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
||||
assert num_compile_ranges in [1, 2]
|
||||
assert num_compile_ranges in [1, 2, 3]
|
||||
|
||||
print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")
|
||||
print("Fusion results:")
|
||||
@@ -107,12 +107,33 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
# Now check the matches
|
||||
for match_name in matches_check:
|
||||
num_ranges_activated = (
|
||||
1 if match_name == "ar_rms_fusion" else num_compile_ranges
|
||||
)
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
|
||||
log_matches = list(int(ms) for ms in log_matches_dict[match_name])
|
||||
|
||||
# AR+RMS skips the largest range; SP skips the smallest.
|
||||
# When both are enabled, AR+RMS activation count is
|
||||
# model-dependent (hidden_size affects threshold), so derive
|
||||
# from log data.
|
||||
if (
|
||||
match_name == "ar_rms_fusion"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
assert (
|
||||
len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
|
||||
), (
|
||||
f"Expected multiple of {tp_size} ar_rms log entries, "
|
||||
f"found {len(log_matches)}"
|
||||
)
|
||||
num_ranges_activated = len(log_matches) // tp_size
|
||||
elif (
|
||||
match_name in ("ar_rms_fusion", "sequence_parallel")
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
num_ranges_activated = num_compile_ranges - 1
|
||||
else:
|
||||
num_ranges_activated = num_compile_ranges
|
||||
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
@@ -122,8 +143,8 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
|
||||
if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
|
||||
# AR+rms+quant takes precedence over rms+quant if activated.
|
||||
# That means we get full matching where ar+rms+quant was not activated,
|
||||
# and less where it was
|
||||
# That means we get full matching where ar+rms+quant was not
|
||||
# activated, and less where it was (only the smallest range).
|
||||
assert sum(m == expected_matches for m in log_matches) == tp_size * (
|
||||
num_ranges_activated - 1
|
||||
), "Expecting full rms+quant fusion where ar+rms+quant not activated"
|
||||
@@ -135,6 +156,43 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
|
||||
f"where ar+rms+quant was activated"
|
||||
)
|
||||
elif (
|
||||
match_name == "async_tp"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
# AsyncTP only finds patterns on ranges where SP ran.
|
||||
n_sp_ranges = num_compile_ranges - 1
|
||||
assert (
|
||||
sum(m == expected_matches for m in log_matches)
|
||||
== tp_size * n_sp_ranges
|
||||
), (
|
||||
f"Expecting {expected_matches} async_tp on "
|
||||
f"{tp_size * n_sp_ranges} SP-range entries, "
|
||||
f"found: {log_matches}"
|
||||
)
|
||||
assert sum(m == 0 for m in log_matches) == tp_size, (
|
||||
f"Expecting 0 async_tp on {tp_size} small-range entries "
|
||||
f"(no SP), found: {log_matches}"
|
||||
)
|
||||
elif (
|
||||
match_name == "ar_rms_fusion"
|
||||
and "sequence_parallel" in matches_check
|
||||
and num_compile_ranges >= 2
|
||||
):
|
||||
# SP consumes allreduce patterns first, so AR+RMS finds
|
||||
# full matches only on the smallest range (no SP).
|
||||
assert sum(m == expected_matches for m in log_matches) == tp_size, (
|
||||
f"Expecting {expected_matches} ar_rms on "
|
||||
f"{tp_size} small-range entries, found: {log_matches}"
|
||||
)
|
||||
assert sum(m == 0 for m in log_matches) == tp_size * (
|
||||
num_ranges_activated - 1
|
||||
), (
|
||||
f"Expecting 0 ar_rms on "
|
||||
f"{tp_size * (num_ranges_activated - 1)} large-range "
|
||||
f"entries (SP took precedence), found: {log_matches}"
|
||||
)
|
||||
else:
|
||||
expected_matches_list = [expected_matches] * n_expected
|
||||
assert sorted(log_matches) == expected_matches_list, (
|
||||
@@ -142,7 +200,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"found: {sorted(log_matches)}"
|
||||
)
|
||||
|
||||
if match_name == "ar_rms_fusion":
|
||||
if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
|
||||
log_matches = re.findall(
|
||||
r"pass_manager.py:\d+] Skipping "
|
||||
r".*AllReduceFusionPass.* with compile range",
|
||||
@@ -155,4 +213,17 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
if match_name == "sequence_parallel" and num_compile_ranges >= 2:
|
||||
log_matches = re.findall(
|
||||
r"pass_manager.py:\d+] Skipping "
|
||||
r".*SequenceParallelismPass.* with compile range",
|
||||
log_holder.text,
|
||||
)
|
||||
|
||||
n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
|
||||
assert len(log_matches) == n_expected, (
|
||||
f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
Reference in New Issue
Block a user