[compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)

Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
Author: Angela Yi
Date: 2025-11-15 03:46:12 -08:00
Committed by: GitHub
Parent: 173b356abf
Commit: f36292dbee
6 changed files with 472 additions and 444 deletions
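
For context, a minimal sketch of the setup this change targets: the sequence-parallelism and async-TP passes enabled while the custom rms_norm/quant_fp8 ops stay disabled, so the pattern matcher has to match the torch-native graphs. Field names mirror the test configuration in the diff below; the import paths and the LLM entry point are assumptions for illustration, not part of this commit.

from vllm import LLM
from vllm.config import CompilationConfig, CompilationMode, PassConfig

# Custom ops disabled ("-" prefix): rms_norm and quant_fp8 run as torch-native
# ops, which the sequence-parallelism matcher can now recognize.
compilation_config = CompilationConfig(
    level=CompilationMode.VLLM_COMPILE,
    custom_ops=["-rms_norm", "-quant_fp8"],
    pass_config=PassConfig(
        enable_noop=True,
        enable_sequence_parallelism=True,
        enable_async_tp=True,
    ),
)

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=2,
    compilation_config=compilation_config,
)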


@@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test
is_blackwell = lambda: current_platform.is_device_capability(100)
"""Are we running on Blackwell, a lot of tests depend on it"""
class Matches(NamedTuple):
attention_fusion: int = 0
allreduce_fusion: int = 0
sequence_parallel: int = 0
async_tp: int = 0
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: AttentionBackendEnum
attention_fusions: int
allreduce_fusions: int | None = None
matches: Matches
MODELS_FP8: list[ModelBackendTestCase] = []
@@ -38,17 +47,33 @@ if current_platform.is_cuda():
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
# TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
# so FI attention+fp8_quant is at least tested once
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER
if is_blackwell()
else AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
# https://github.com/vllm-project/vllm/issues/28568
# TODO FlashInfer attn broken on Blackwell for llama4:
# https://github.com/vllm-project/vllm/issues/28604
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=95, # MLP is MoE, no fusion there
),
),
]
@@ -57,8 +82,12 @@ if current_platform.is_cuda():
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=32,
allreduce_fusions=65,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
]
@@ -68,15 +97,23 @@ if current_platform.is_cuda():
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
matches=Matches(
attention_fusion=0,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
),
ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=97,
matches=Matches(
attention_fusion=0,
allreduce_fusion=97,
sequence_parallel=97,
async_tp=96, # MLP is MoE, half the fusions of dense
),
),
]
@@ -86,19 +123,19 @@ elif current_platform.is_rocm():
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32,
matches=Matches(attention_fusion=32),
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
matches=Matches(attention_fusion=32),
),
]
@@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
@@ -118,15 +154,14 @@ def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: AttentionBackendEnum,
attention_fusions: int,
allreduce_fusions: int,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
not is_blackwell() or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@@ -169,12 +204,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)
matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions
assert len(log_matches) == 1, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
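# In CompilationConfig.custom_ops, a "+" prefix enables the custom-op kernel
# and a "-" prefix disables it, falling back to the torch-native implementation.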
@@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
@@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: AttentionBackendEnum,
attention_fusions: int,
allreduce_fusions: int,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
@@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
@@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
matches = re.findall(
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == attention_fusions
assert int(matches[1]) == attention_fusions
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion
matches = re.findall(
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == allreduce_fusions
assert int(matches[1]) == allreduce_fusions
assert int(log_matches[0]) == matches.allreduce_fusion
assert int(log_matches[1]) == matches.allreduce_fusion
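# End-to-end check that attn+quant fusion, sequence parallelism, and async TP
# all fire together across 2 GPUs, with the custom rms_norm/quant_fp8 ops
# toggled on and off via the parametrized custom_ops string.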
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
model_name: str,
model_kwargs: dict,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if is_blackwell():
# TODO: https://github.com/vllm-project/vllm/issues/27893
pytest.skip("Blackwell is not supported for AsyncTP pass")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER:
if not has_flashinfer():
pytest.skip("FlashInfer backend requires flashinfer installed")
if not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_sequence_parallelism=True,
enable_async_tp=True,
),
# Inductor also caches custom passes (via uuid) by default
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
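# Two TP ranks are spawned, so each pass is expected to log its match count
# twice, presumably once per rank; expected counts come from the Matches tuple.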
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion
log_matches = re.findall(
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.sequence_parallel
assert int(log_matches[1]) == matches.sequence_parallel
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.async_tp
assert int(log_matches[1]) == matches.async_tp
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):