diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 9c3e84af9..caaed14af 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -551,7 +551,7 @@ steps:
 - label: LoRA Test %N # 20min each
   timeout_in_minutes: 30
   mirror_hardwares: [amdexperimental]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   source_file_dependencies:
   - vllm/lora
@@ -647,7 +647,7 @@ steps:
 - label: Kernels Attention Test %N # 23min
   timeout_in_minutes: 35
   mirror_hardwares: [amdexperimental, amdproduction]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   source_file_dependencies:
   - csrc/attention/
@@ -662,7 +662,7 @@ steps:
 - label: Kernels Quantization Test %N # 64min
   timeout_in_minutes: 90
   mirror_hardwares: [amdexperimental]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   source_file_dependencies:
   - csrc/quantization/
@@ -675,7 +675,7 @@ steps:
 - label: Kernels MoE Test %N # 40min
   timeout_in_minutes: 60
   mirror_hardwares: [amdexperimental, amdproduction]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   source_file_dependencies:
   - csrc/quantization/cutlass_w8a8/moe/
@@ -753,7 +753,7 @@ steps:
 - label: Benchmarks # 11min
   timeout_in_minutes: 20
   mirror_hardwares: [amdexperimental, amdproduction]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   working_dir: "/vllm-workspace/.buildkite"
   source_file_dependencies:
@@ -764,7 +764,7 @@ steps:
 - label: Benchmarks CLI Test # 7min
   timeout_in_minutes: 20
   mirror_hardwares: [amdexperimental, amdproduction]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   source_file_dependencies:
   - vllm/
@@ -838,7 +838,7 @@ steps:
 - label: Basic Models Tests (Extra Initialization) %N
   timeout_in_minutes: 45
   mirror_hardwares: [amdexperimental, amdproduction]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   torch_nightly: true
   source_file_dependencies:
@@ -900,7 +900,7 @@ steps:
 - label: Language Models Tests (Extra Standard) %N
   timeout_in_minutes: 45
   mirror_hardwares: [amdexperimental]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   torch_nightly: true
   source_file_dependencies:
@@ -921,7 +921,7 @@ steps:
 - label: Language Models Tests (Hybrid) %N
   timeout_in_minutes: 75
   mirror_hardwares: [amdexperimental]
-  agent_pool: mi325_1
+  agent_pool: mi325_8
   # grade: Blocking
   torch_nightly: true
   source_file_dependencies:
@@ -1190,16 +1190,16 @@ steps:
   - vllm/model_executor/layers/layernorm.py
   - vllm/model_executor/layers/activation.py
   - vllm/model_executor/layers/quantization/input_quant_fp8.py
-  - tests/compile/test_fusion_attn.py
-  - tests/compile/test_silu_mul_quant_fusion.py
-  - tests/compile/distributed/test_fusion_all_reduce.py
+  - tests/compile/passes/test_fusion_attn.py
+  - tests/compile/passes/test_silu_mul_quant_fusion.py
+  - tests/compile/passes/distributed/test_fusion_all_reduce.py
   - tests/compile/fullgraph/test_full_graph.py
   commands:
   - nvidia-smi
-  - pytest -v -s tests/compile/test_fusion_attn.py
-  - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
+  - pytest -v -s tests/compile/passes/test_fusion_attn.py
+  - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
   # this runner has 2 GPUs available even though num_gpus=2 is not set
-  - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
   # # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
   # # Wrap with quotes to escape yaml
@@ -1556,15 +1556,15 @@ steps:
   working_dir: "/vllm-workspace/"
   num_gpus: 2
   commands:
-  - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
-  - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
-  - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+  - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
+  - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
+  - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
   #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
   # - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
   # Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
   # in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
-  - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
+  - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
   - pytest -v -s tests/distributed/test_context_parallel.py
   - HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization
   - pytest -v -s tests/v1/distributed/test_dbo.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index e3146948b..1e28f520d 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -519,6 +519,7 @@ steps:
   # However, find does not normally propagate error codes, so we combine it with xargs
   # (using -0 for proper path handling)
   - "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
+  - pytest -s -v compile/passes --ignore compile/passes/distributed
 
 - label: PyTorch Fullgraph Smoke Test # 15min
   timeout_in_minutes: 30
@@ -1080,14 +1081,14 @@ steps:
   - vllm/model_executor/layers/quantization/input_quant_fp8.py
   - tests/compile/test_fusion_attn.py
   - tests/compile/test_silu_mul_quant_fusion.py
-  - tests/compile/distributed/test_fusion_all_reduce.py
+  - tests/compile/passes/distributed/test_fusion_all_reduce.py
   - tests/compile/fullgraph/test_full_graph.py
   commands:
   - nvidia-smi
   - pytest -v -s tests/compile/test_fusion_attn.py
   - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
   # this runner has 2 GPUs available even though num_gpus=2 is not set
-  - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
   # # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
   # # Wrap with quotes to escape yaml
   # - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
@@ -1421,8 +1422,8 @@ steps:
   commands:
   - export VLLM_TEST_CLEAN_GPU_MEMORY=1
   # Run sequence parallel tests
-  - pytest -v -s tests/distributed/test_sequence_parallel.py
-  - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
+  - pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
+  - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
 
 - label: Distributed Tests (H100) # optional
   gpu: h100
@@ -1430,7 +1431,7 @@ steps:
   optional: true
   working_dir: "/vllm-workspace/"
   num_gpus: 2
   commands:
-  - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
+  - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
   - pytest -v -s tests/distributed/test_context_parallel.py
   - VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
   - pytest -v -s tests/v1/distributed/test_dbo.py
diff --git a/.buildkite/test_areas/compile.yaml b/.buildkite/test_areas/compile.yaml
index e8cf9e8bd..56fc011c7 100644
--- a/.buildkite/test_areas/compile.yaml
+++ b/.buildkite/test_areas/compile.yaml
@@ -2,7 +2,7 @@ group: Compile
 depends_on:
 - image-build
 steps:
-- label: Sequence Parallel Tests (2 GPUs)
+- label: Sequence Parallel Correctness Tests (2 GPUs)
   timeout_in_minutes: 50
   working_dir: "/vllm-workspace/"
   num_devices: 2
@@ -11,12 +11,12 @@ steps:
   - vllm/compilation/
   - vllm/v1/worker/
   - vllm/v1/cudagraph_dispatcher.py
-  - tests/distributed/test_sequence_parallel.py
+  - tests/compile/correctness_e2e/test_sequence_parallel.py
   commands:
   - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-  - pytest -v -s tests/distributed/test_sequence_parallel.py
+  - pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
 
-- label: Sequence Parallel Tests (2xH100)
+- label: Sequence Parallel Correctness Tests (2xH100)
   timeout_in_minutes: 50
   working_dir: "/vllm-workspace/"
   device: h100
@@ -24,24 +24,30 @@ steps:
   num_devices: 2
   commands:
   - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-  - pytest -v -s tests/distributed/test_sequence_parallel.py
+  - pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
+
+- label: AsyncTP Correctness Tests (2xH100)
+  timeout_in_minutes: 50
+  working_dir: "/vllm-workspace/"
+  device: h100
+  optional: true
+  num_devices: 2
+  commands:
+  - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+  - pytest -v -s tests/compile/correctness_e2e/test_async_tp.py
 
 - label: Distributed Compile Unit Tests (2xH100)
-  timeout_in_minutes: 40
+  timeout_in_minutes: 20
   working_dir: "/vllm-workspace/"
   device: h100
   num_devices: 2
   source_file_dependencies:
   - vllm/compilation/
   - vllm/model_executor/layers
-  - tests/compile/distributed/test_fusion_all_reduce.py
-  - tests/compile/distributed/test_sequence_parallelism.py
-  - tests/compile/distributed/test_async_tp.py
+  - tests/compile/passes/distributed/
   commands:
   - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-  - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
-  - pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
-  - pytest -v -s tests/compile/distributed/test_async_tp.py
+  - pytest -s -v tests/compile/passes/distributed
 
 - label: Fusion and Compile Unit Tests (B200)
   timeout_in_minutes: 20
@@ -55,17 +61,17 @@ steps:
   - vllm/model_executor/layers/attention/attention.py
   - vllm/v1/attention/backends/flashinfer.py
   - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
-  - tests/compile/test_fusion_attn.py
-  - tests/compile/test_silu_mul_quant_fusion.py
-  - tests/compile/distributed/test_fusion_all_reduce.py
+  - tests/compile/passes/test_fusion_attn.py
+  - tests/compile/passes/test_silu_mul_quant_fusion.py
+  - tests/compile/passes/distributed/test_fusion_all_reduce.py
   - tests/compile/fullgraph/test_full_graph.py
   commands:
   # b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
   - nvidia-smi
-  - pytest -v -s tests/compile/test_fusion_attn.py -k FLASHINFER
-  - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
+  - pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
+  - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
   # this runner has 2 GPUs available even though num_devices=2 is not set
-  - pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
   # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
   # TODO(luka) move to H100 once pass tests run on H100
   - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
diff --git a/.buildkite/test_areas/pytorch.yaml b/.buildkite/test_areas/pytorch.yaml
index 1ac3eec58..97cb3cedc 100644
--- a/.buildkite/test_areas/pytorch.yaml
+++ b/.buildkite/test_areas/pytorch.yaml
@@ -3,7 +3,7 @@ depends_on:
 - image-build
 steps:
 - label: PyTorch Compilation Unit Tests
-  timeout_in_minutes: 30
+  timeout_in_minutes: 10
   source_file_dependencies:
   - vllm/
   - tests/compile
@@ -17,6 +17,14 @@ steps:
   # (using -0 for proper path handling)
   - "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
 
+- label: PyTorch Compilation Passes Unit Tests
+  timeout_in_minutes: 20
+  source_file_dependencies:
+  - vllm/
+  - tests/compile/passes
+  commands:
+  - pytest -s -v compile/passes --ignore compile/passes/distributed
+
 - label: PyTorch Fullgraph Smoke Test
   timeout_in_minutes: 35
   source_file_dependencies:
diff --git a/tests/compile/backend.py b/tests/compile/backend.py
index fa4261900..ec4685324 100644
--- a/tests/compile/backend.py
+++ b/tests/compile/backend.py
@@ -11,10 +11,10 @@ from torch import fx
 from torch._ops import OpOverload
 from torch.fx._utils import lazy_format_graph_code
 
-from vllm.compilation.fx_utils import find_op_nodes
-from vllm.compilation.inductor_pass import InductorPass
-from vllm.compilation.pass_manager import with_pattern_match_debug
-from vllm.compilation.vllm_inductor_pass import VllmInductorPass
+from vllm.compilation.passes.fx_utils import find_op_nodes
+from vllm.compilation.passes.inductor_pass import InductorPass
+from vllm.compilation.passes.pass_manager import with_pattern_match_debug
+from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
diff --git a/tests/compile/distributed/__init__.py b/tests/compile/correctness_e2e/__init__.py
similarity index 100%
rename from tests/compile/distributed/__init__.py
rename to tests/compile/correctness_e2e/__init__.py
diff --git a/tests/compile/correctness_e2e/test_async_tp.py b/tests/compile/correctness_e2e/test_async_tp.py
new file mode 100644
index 000000000..cf9c75d91
--- /dev/null
+++ b/tests/compile/correctness_e2e/test_async_tp.py
@@ -0,0 +1,79 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+
+from tests.models.registry import HF_EXAMPLE_MODELS
+from tests.utils import (
+    compare_two_settings,
+    create_new_process_for_each_test,
+)
+from vllm.config import (
+    CompilationMode,
+)
+
+
+@create_new_process_for_each_test()
+@pytest.mark.parametrize(
+    "model_id",
+    ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
+)
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("async_tp_enabled", [True])
+@pytest.mark.parametrize("distributed_backend", ["mp"])
+@pytest.mark.parametrize("eager_mode", [False, True])
+def test_async_tp_pass_correctness(
+    model_id: str,
+    tp_size: int,
+    async_tp_enabled: bool,
+    distributed_backend: str,
+    eager_mode: bool,
+    num_gpus_available: int,
+):
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
+    model_info.check_transformers_version(on_fail="skip")
+    model_info.check_available_online(on_fail="skip")
+
+    pp_size = 1
+    if num_gpus_available < tp_size:
+        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
+
+    common_args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "8",
+    ]
+    if eager_mode:
+        common_args.append("--enforce-eager")
+
+    compilation_config = {
+        "mode": CompilationMode.VLLM_COMPILE,
+        "compile_sizes": [2, 4, 8],
+        "splitting_ops": [],
+        "pass_config": {"fuse_gemm_comms": async_tp_enabled},
+    }
+
+    async_tp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--distributed-executor-backend",
+        distributed_backend,
+        "--compilation_config",
+        json.dumps(compilation_config),
+    ]
+
+    tp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--distributed-executor-backend",
+        "mp",
+    ]
+
+    compare_two_settings(model_id, async_tp_args, tp_args, method="generate")
diff --git a/tests/distributed/test_sequence_parallel.py b/tests/compile/correctness_e2e/test_sequence_parallel.py
similarity index 98%
rename from tests/distributed/test_sequence_parallel.py
rename to tests/compile/correctness_e2e/test_sequence_parallel.py
index 0a7907aad..6c084f603 100644
--- a/tests/distributed/test_sequence_parallel.py
+++ b/tests/compile/correctness_e2e/test_sequence_parallel.py
@@ -21,8 +21,8 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
-from ..models.registry import HF_EXAMPLE_MODELS
-from ..utils import compare_two_settings, create_new_process_for_each_test
+from ...models.registry import HF_EXAMPLE_MODELS
+from ...utils import compare_two_settings, create_new_process_for_each_test
 
 logger = init_logger("test_sequence_parallel")
diff --git a/tests/compile/fusions_e2e/common.py b/tests/compile/fusions_e2e/common.py
index d950bf5b6..284a9d66b 100644
--- a/tests/compile/fusions_e2e/common.py
+++ b/tests/compile/fusions_e2e/common.py
@@ -82,19 +82,17 @@ INDUCTOR_GRAPH_PARTITION = [
 ]
 
 FUSION_LOG_PATTERNS: dict[str, re.Pattern] = {
-    "rms_quant_fusion": re.compile(
-        r"\[(?:compilation/)?fusion.py:\d+] Replaced (\d+) patterns"
-    ),
-    "act_quant_fusion": re.compile(
-        r"activation_quant_fusion.py:\d+] Replaced (\d+) patterns"
-    ),
+    "rms_quant_fusion": re.compile(r"rms_quant_fusion.py:\d+] Replaced (\d+) patterns"),
+    "act_quant_fusion": re.compile(r"act_quant_fusion.py:\d+] Replaced (\d+) patterns"),
     "norm_rope_fusion": re.compile(
         r"qk_norm_rope_fusion.py:\d+] Fused QK Norm\+RoPE on (\d+) sites"
     ),
     "attn_quant_fusion": re.compile(
-        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes"
+        r"attn_quant_fusion.py:\d+] Fused quant onto (\d+) attention nodes"
+    ),
+    "ar_rms_fusion": re.compile(
+        r"allreduce_rms_fusion.py:\d+] Replaced (\d+) patterns"
     ),
-    "ar_rms_fusion": re.compile(r"collective_fusion.py:\d+] Replaced (\d+) patterns"),
     "sequence_parallel": re.compile(
         r"sequence_parallelism.py:\d+] Replaced (\d+) patterns"
     ),
diff --git a/tests/compile/passes/__init__.py b/tests/compile/passes/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/compile/passes/distributed/__init__.py b/tests/compile/passes/distributed/__init__.py
new file mode 100644
index
000000000..e69de29bb diff --git a/tests/compile/distributed/test_async_tp.py b/tests/compile/passes/distributed/test_async_tp.py similarity index 84% rename from tests/compile/distributed/test_async_tp.py rename to tests/compile/passes/distributed/test_async_tp.py index 3b96fa65d..df7747d1a 100644 --- a/tests/compile/distributed/test_async_tp.py +++ b/tests/compile/passes/distributed/test_async_tp.py @@ -1,16 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json import pytest import torch import vllm.envs as envs -from vllm.compilation.collective_fusion import AsyncTPPass +from tests.compile.backend import TestBackend +from tests.utils import ( + multi_gpu_test, +) +from vllm.compilation.passes.fusion.collective_fusion import AsyncTPPass from vllm.config import ( CompilationConfig, - CompilationMode, DeviceConfig, ModelConfig, PassConfig, @@ -29,14 +31,6 @@ from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed -from ...models.registry import HF_EXAMPLE_MODELS -from ...utils import ( - compare_two_settings, - create_new_process_for_each_test, - multi_gpu_test, -) -from ..backend import TestBackend - FP8_DTYPE = current_platform.fp8_dtype() prompts = [ @@ -377,67 +371,3 @@ def async_tp_pass_on_test_model( # In post-nodes, fused_matmul_reduce_scatter or \ # fused_all_gather_matmul should exist backend.check_after_ops(model.ops_in_model_after()) - - -@create_new_process_for_each_test() -@pytest.mark.parametrize( - "model_id", - ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"], -) -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("async_tp_enabled", [True]) -@pytest.mark.parametrize("distributed_backend", ["mp"]) -@pytest.mark.parametrize("eager_mode", [False, True]) -def test_async_tp_pass_correctness( - model_id: str, - tp_size: int, - async_tp_enabled: bool, - distributed_backend: str, - eager_mode: bool, - num_gpus_available: int, -): - model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) - model_info.check_transformers_version(on_fail="skip") - model_info.check_available_online(on_fail="skip") - - pp_size = 1 - if num_gpus_available < tp_size: - pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") - - common_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "2048", - "--max-num-seqs", - "8", - ] - if eager_mode: - common_args.append("--enforce-eager") - - compilation_config = { - "mode": CompilationMode.VLLM_COMPILE, - "compile_sizes": [2, 4, 8], - "splitting_ops": [], - "pass_config": {"fuse_gemm_comms": async_tp_enabled}, - } - - async_tp_args = [ - *common_args, - "--tensor-parallel-size", - str(tp_size), - "--distributed-executor-backend", - distributed_backend, - "--compilation_config", - json.dumps(compilation_config), - ] - - tp_args = [ - *common_args, - "--tensor-parallel-size", - str(tp_size), - "--distributed-executor-backend", - "mp", - ] - - compare_two_settings(model_id, async_tp_args, tp_args, method="generate") diff --git a/tests/compile/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py similarity index 95% rename from tests/compile/distributed/test_fusion_all_reduce.py rename to tests/compile/passes/distributed/test_fusion_all_reduce.py index d2d90adae..f13f49b67 100644 --- a/tests/compile/distributed/test_fusion_all_reduce.py +++ 
b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -6,11 +6,15 @@ import pytest import torch import vllm.envs as envs +from tests.compile.backend import TestBackend +from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.collective_fusion import AllReduceFusionPass -from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass +from vllm.compilation.passes.utility.fix_functionalization import ( + FixFunctionalizationPass, +) +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, CompilationMode, @@ -33,9 +37,6 @@ from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed -from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test -from ..backend import TestBackend - class TestAllReduceRMSNormModel(torch.nn.Module): def __init__(self, hidden_size=16, token_num=16, eps=1e-6): diff --git a/tests/compile/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py similarity index 94% rename from tests/compile/distributed/test_sequence_parallelism.py rename to tests/compile/passes/distributed/test_sequence_parallelism.py index d8a1a4288..46363a9a4 100644 --- a/tests/compile/distributed/test_sequence_parallelism.py +++ b/tests/compile/passes/distributed/test_sequence_parallelism.py @@ -5,12 +5,14 @@ import pytest import torch import vllm.envs as envs -from vllm.compilation.fusion import RMSNormQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.compilation.sequence_parallelism import SequenceParallelismPass -from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from tests.compile.backend import TestBackend +from tests.utils import TestFP8Layer, multi_gpu_test +from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass +from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass +from vllm.compilation.passes.fx_utils import find_auto_fn +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass +from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass from vllm.config import ( CompilationConfig, CUDAGraphMode, @@ -34,9 +36,6 @@ from vllm.platforms import current_platform from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed -from ...utils import TestFP8Layer, multi_gpu_test -from ..backend import TestBackend - FP8_DTYPE = current_platform.fp8_dtype() prompts = [ "Hello, my name is", diff --git a/tests/compile/test_functionalization.py b/tests/compile/passes/test_functionalization.py similarity index 93% rename from tests/compile/test_functionalization.py rename to tests/compile/passes/test_functionalization.py index 9791493fd..e8da56b26 100644 --- 
a/tests/compile/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -5,12 +5,18 @@ import pytest import torch import vllm.envs as envs -from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass -from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import RMSNormQuantFusionPass -from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass +from tests.compile.backend import TestBackend +from tests.utils import TestFP8Layer +from vllm.compilation.passes.fusion.act_quant_fusion import ( + ActivationQuantFusionPass, +) +from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass +from vllm.compilation.passes.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func +from vllm.compilation.passes.utility.fix_functionalization import ( + FixFunctionalizationPass, +) +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, ModelConfig, @@ -26,9 +32,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform -from ..utils import TestFP8Layer -from .backend import TestBackend - TEST_FP8 = current_platform.supports_fp8() FP8_DTYPE = current_platform.fp8_dtype() diff --git a/tests/compile/test_fuse_act_padding.py b/tests/compile/passes/test_fuse_act_padding.py similarity index 93% rename from tests/compile/test_fuse_act_padding.py rename to tests/compile/passes/test_fuse_act_padding.py index d2670cd64..f3f3bda47 100644 --- a/tests/compile/test_fuse_act_padding.py +++ b/tests/compile/passes/test_fuse_act_padding.py @@ -6,9 +6,10 @@ import pytest import torch import vllm.config +from tests.compile.backend import TestBackend from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, CompilationMode, @@ -19,8 +20,6 @@ from vllm.config import ( from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.utils import rocm_unquantized_gemm -from .backend import TestBackend - class TestModel(torch.nn.Module): def __init__( @@ -95,7 +94,7 @@ def test_fuse_act_padding( ) with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: - from vllm.compilation.rocm_aiter_fusion import ( + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( RocmAiterTritonAddRMSNormPadFusionPass, ) diff --git a/tests/compile/test_fusion.py b/tests/compile/passes/test_fusion.py similarity index 94% rename from tests/compile/test_fusion.py rename to tests/compile/passes/test_fusion.py index e4a4fef23..a2128150f 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/passes/test_fusion.py @@ -7,12 +7,18 @@ import torch import vllm.config import vllm.plugins +from tests.compile.backend import TestBackend +from tests.utils import TestBlockFP8Layer, TestFP8Layer from vllm._aiter_ops import IS_AITER_FOUND, 
rocm_aiter_ops -from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass -from vllm.compilation.fx_utils import find_op_nodes -from vllm.compilation.matcher_utils import QUANT_OPS -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS +from vllm.compilation.passes.fusion.rms_quant_fusion import ( + FUSED_OPS, + FusedRMSQuantKey, + RMSNormQuantFusionPass, +) +from vllm.compilation.passes.fx_utils import find_op_nodes +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, CompilationMode, @@ -51,9 +57,6 @@ from vllm.utils.deep_gemm import ( is_deep_gemm_supported, ) -from ..utils import TestBlockFP8Layer, TestFP8Layer -from .backend import TestBackend - FP8_DTYPE = current_platform.fp8_dtype() RMS_OP = torch.ops._C.rms_norm.default @@ -223,7 +226,7 @@ class TestModel(torch.nn.Module): if self.use_aiter_fusion: if self.group_shape.is_per_group(): # Blockwise aiter fusion - from vllm.compilation.rocm_aiter_fusion import ( + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( AiterFusedAddRMSFp8GroupQuantPattern, AiterRMSFp8GroupQuantPattern, ) @@ -234,7 +237,7 @@ class TestModel(torch.nn.Module): ] else: # Per-token aiter fusion - from vllm.compilation.rocm_aiter_fusion import ( + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( AiterFusedAddRMSNormDynamicQuantPattern, AiterRMSNormDynamicQuantPattern, ) @@ -410,7 +413,9 @@ def test_aiter_fusion_rmsnorm_quant( ) with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: - from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) m.setenv("VLLM_ROCM_USE_AITER", "1") diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py similarity index 97% rename from tests/compile/test_fusion_attn.py rename to tests/compile/passes/test_fusion_attn.py index 6515c5222..75d5c42f0 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/passes/test_fusion_attn.py @@ -6,14 +6,14 @@ import pytest import torch._dynamo from tests.compile.backend import LazyInitPass, TestBackend -from tests.utils import flat_product +from tests.utils import TestFP8Layer, flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass -from vllm.compilation.fx_utils import find_op_nodes -from vllm.compilation.matcher_utils import QUANT_OPS -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass +from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS +from vllm.compilation.passes.fx_utils import find_op_nodes +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( AttentionConfig, CacheConfig, @@ -38,8 +38,6 @@ from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.registry import 
AttentionBackendEnum from vllm.v1.kv_cache_interface import AttentionSpec -from ..utils import TestFP8Layer - FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/passes/test_noop_elimination.py similarity index 97% rename from tests/compile/test_noop_elimination.py rename to tests/compile/passes/test_noop_elimination.py index 02bc40230..412e8056f 100644 --- a/tests/compile/test_noop_elimination.py +++ b/tests/compile/passes/test_noop_elimination.py @@ -5,11 +5,10 @@ import pytest import torch import vllm -from vllm.compilation.noop_elimination import NoOpEliminationPass +from tests.compile.backend import TestBackend +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig -from .backend import TestBackend - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) # Important edge case is when `num_tokens == buffer_size` diff --git a/tests/compile/test_pass_manager.py b/tests/compile/passes/test_pass_manager.py similarity index 95% rename from tests/compile/test_pass_manager.py rename to tests/compile/passes/test_pass_manager.py index df8e5b69f..9ba989228 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/passes/test_pass_manager.py @@ -5,12 +5,12 @@ import copy import pytest import torch -from vllm.compilation.inductor_pass import ( +from vllm.compilation.passes.inductor_pass import ( CallableInductorPass, InductorPass, pass_context, ) -from vllm.compilation.pass_manager import PostGradPassManager +from vllm.compilation.passes.pass_manager import PostGradPassManager from vllm.config import ModelConfig, VllmConfig from vllm.config.utils import Range diff --git a/tests/compile/test_qk_norm_rope_fusion.py b/tests/compile/passes/test_qk_norm_rope_fusion.py similarity index 95% rename from tests/compile/test_qk_norm_rope_fusion.py rename to tests/compile/passes/test_qk_norm_rope_fusion.py index 19511b787..bb8bc043e 100644 --- a/tests/compile/test_qk_norm_rope_fusion.py +++ b/tests/compile/passes/test_qk_norm_rope_fusion.py @@ -5,13 +5,17 @@ import pytest import torch from tests.compile.backend import TestBackend -from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.compilation.qk_norm_rope_fusion import ( +from vllm.compilation.passes.fusion.matcher_utils import ( + FLASHINFER_ROTARY_OP, + RMS_OP, + ROTARY_OP, +) +from vllm.compilation.passes.fusion.qk_norm_rope_fusion import ( FUSED_QK_ROPE_OP, QKNormRoPEFusionPass, ) +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, CompilationMode, diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/passes/test_silu_mul_quant_fusion.py similarity index 94% rename from tests/compile/test_silu_mul_quant_fusion.py rename to tests/compile/passes/test_silu_mul_quant_fusion.py index dec5ca8de..c5ef01501 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/passes/test_silu_mul_quant_fusion.py @@ -6,17 +6,19 @@ import pytest import torch import vllm.envs as envs +from tests.compile.backend import TestBackend from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor +from 
tests.utils import TestFP8Layer from vllm._aiter_ops import IS_AITER_FOUND from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -from vllm.compilation.activation_quant_fusion import ( +from vllm.compilation.passes.fusion.act_quant_fusion import ( FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass, ) -from vllm.compilation.fusion import QUANT_OPS -from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.compilation.post_cleanup import PostCleanupPass +from vllm.compilation.passes.fusion.rms_quant_fusion import QUANT_OPS +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, CompilationMode, @@ -48,9 +50,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.platforms import current_platform -from ..utils import TestFP8Layer -from .backend import TestBackend - FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -100,7 +99,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module): class TestSiluMulNvfp4QuantModel(torch.nn.Module): def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): super().__init__() - from vllm.compilation.activation_quant_fusion import ( + from vllm.compilation.passes.fusion.act_quant_fusion import ( silu_and_mul_nvfp4_quant_supported, ) @@ -239,7 +238,7 @@ def test_fusion_silu_and_mul_quant( with set_current_vllm_config(config): fusion_passes = [ActivationQuantFusionPass(config)] if IS_AITER_FOUND: - from vllm.compilation.rocm_aiter_fusion import ( + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( RocmAiterSiluMulFp8GroupQuantFusionPass, ) diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py index 14ae8233f..c90454ed0 100644 --- a/tests/compile/test_compile_ranges.py +++ b/tests/compile/test_compile_ranges.py @@ -10,7 +10,7 @@ from torch import nn import tests.compile.silly_attention # noqa from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.inductor_pass import ( +from vllm.compilation.passes.inductor_pass import ( InductorPass, get_pass_context, ) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index f1170b1b8..eb2f0669e 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -8,7 +8,9 @@ import pytest from pydantic import ValidationError from vllm.compilation.counter import compilation_counter -from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.passes.utility.fix_functionalization import ( + FixFunctionalizationPass, +) from vllm.config import ( CompilationConfig, CUDAGraphMode, diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 38e3e038a..6d1e2daf9 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -8,7 +8,7 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from vllm.compilation.backends import split_graph -from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.passes.fx_utils import find_op_nodes # This import automatically registers `torch.ops.silly.attention` from . 
import silly_attention # noqa: F401 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 85833a7a8..ce2da3cf2 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -22,11 +22,6 @@ from torch._dispatch.python import enable_python_dispatcher from torch._logging._internal import trace_structured import vllm.envs as envs -from vllm.compilation.inductor_pass import pass_context -from vllm.compilation.partition_rules import ( - inductor_partition_rule_context, - should_split, -) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import DynamicShapesType from vllm.config.utils import Range, hash_factors @@ -44,8 +39,12 @@ from .compiler_interface import ( is_compile_cache_enabled, ) from .counter import compilation_counter -from .inductor_pass import InductorPass -from .pass_manager import PostGradPassManager +from .partition_rules import ( + inductor_partition_rule_context, + should_split, +) +from .passes.inductor_pass import InductorPass, pass_context +from .passes.pass_manager import PostGradPassManager logger = init_logger(__name__) diff --git a/vllm/compilation/passes/__init__.py b/vllm/compilation/passes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/compilation/passes/fusion/__init__.py b/vllm/compilation/passes/fusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py similarity index 97% rename from vllm/compilation/activation_quant_fusion.py rename to vllm/compilation/passes/fusion/act_quant_fusion.py index 1eb23bf03..e14100384 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.platforms import current_platform -from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 -from .inductor_pass import enable_fake_mode +from ..inductor_pass import enable_fake_mode +from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul -from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass +from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 logger = init_logger(__name__) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py similarity index 68% rename from vllm/compilation/collective_fusion.py rename to vllm/compilation/passes/fusion/allreduce_rms_fusion.py index d7514a170..0b343fd16 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -8,7 +8,6 @@ import torch._inductor.pattern_matcher as pm import torch.fx as fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass -from torch.distributed._symmetric_memory import enable_symm_mem_for_group from vllm.config import VllmConfig from vllm.config.utils import Range @@ -24,12 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op -from .inductor_pass import enable_fake_mode +from ..inductor_pass import enable_fake_mode +from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from 
.matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm -from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() +logger = init_logger(__name__) + flashinfer_comm: ModuleType | None = None if find_spec("flashinfer"): try: @@ -45,406 +46,6 @@ logger = init_logger(__name__) if hasattr(torch.ops._C, "scaled_fp4_quant"): STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default - -class BasePattern: - def __init__(self, dtype: torch.dtype, device: str | None) -> None: - self.dtype = dtype - self.device = device - self.tp = get_tp_group() - self.tp_size = get_tensor_model_parallel_world_size() - - -class GEMMReduceScatterPattern(BasePattern): - def get_inputs(self) -> list[torch.Tensor]: - mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) - mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - return [mul, mm_weight] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor: - mm = torch.ops.aten.mm.default(mul, mm_weight) - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - mm, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name, - ) - return reduce_scatter - - def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor: - gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( - mul, - mm_weight, - "avg", - scatter_dim=0, - group_name=self.tp.device_group.group_name, - ) - - return gemm_rs - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class AllGatherGEMMPattern(BasePattern): - def get_inputs(self) -> list[torch.Tensor]: - x = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - - return [x, weight] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - x: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - all_gather = torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name, - ) - - return torch.ops.aten.mm.default(all_gather, weight) - - def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( - x, - [weight], - gather_dim=0, - group_name=self.tp.device_group.group_name, - ) - return mm_outputs - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class ScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self) -> list[torch.Tensor]: - input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - mm_weight = ( - torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - .contiguous() - .transpose(0, 1) - ) - scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) - scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - return [input, mm_weight, scale_a, scale_b] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - input: torch.Tensor, - mat2: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - ) -> torch.Tensor: - scaled_mm = torch.ops.aten._scaled_mm.default( - input, - mat2=mat2, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype, - ) - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - scaled_mm, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name, - ) - return 
reduce_scatter - - def replacement( - input: torch.Tensor, - mat2: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - ) -> torch.Tensor: - # Calculate output shape: input @ mat2 with scatter_dim reduced - output_shape = [*input.shape[:-1], mat2.shape[1]] - scatter_dim = 0 - gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter( - input, - mat2, - scale_a, - scale_b, - "avg", - scatter_dim, # orig_scatter_dim - scatter_dim, # scatter_dim_after_maybe_reshape - self.tp.device_group.group_name, - output_shape, - None, # bias - None, # result_scale - self.dtype, # out_dtype - False, # use_fast_accum - ) - - return gemm_rs - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class AllGatherScaledMMPattern(BasePattern): - def get_inputs(self) -> list[torch.Tensor]: - x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = ( - torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - .contiguous() - .transpose(0, 1) - ) - - s1 = x.shape[0] * self.tp_size - - scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) - scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - - return [x, weight, scale_a, scale_b] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - x: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - ) -> torch.Tensor: - all_gather = torch.ops.vllm.all_gather.default( - x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name - ) - - return torch.ops.aten._scaled_mm.default( - all_gather, - mat2=weight, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype, - ) - - def replacement( - x: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - ) -> torch.Tensor: - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa - x, - [weight], - scale_a, - [scale_b], - gather_dim=0, - biases=[None], - result_scales=[None], - out_dtypes=[self.dtype], - use_fast_accum=[False], - group_name=self.tp.device_group.group_name, - ) - return mm_outputs - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class CutlassScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self) -> list[torch.Tensor]: - input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - mm_weight = ( - torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - .contiguous() - .transpose(0, 1) - ) - scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) - scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - - cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype) - return [input, mm_weight, scale_a, scale_b, cutlass_mm_output] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - input: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - cutlass_mm_output: torch.Tensor, - ) -> torch.Tensor: - cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.cutlass_scaled_mm.default, - out=cutlass_mm_output, - a=input, - b=weight, - a_scales=scale_a, - b_scales=scale_b, - bias=None, - ) - - reduce_scatter = torch.ops.vllm.reduce_scatter.default( - cutlass_scaled_mm[1], - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name, - ) - return reduce_scatter - - def replacement( - input: torch.Tensor, - mat2: torch.Tensor, - 
scale_a: torch.Tensor, - scale_b: torch.Tensor, - cutlass_mm_output: torch.Tensor, - ) -> torch.Tensor: - # Calculate output shape: input @ mat2 with scatter_dim reduced - output_shape = [*input.shape[:-1], mat2.shape[1]] - scatter_dim = 0 - gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter( - input, - mat2, - scale_a, - scale_b, - "avg", - scatter_dim, # orig_scatter_dim - scatter_dim, # scatter_dim_after_maybe_reshape - self.tp.device_group.group_name, - output_shape, - None, # bias - None, # result_scale - self.dtype, # out_dtype - False, # use_fast_accum - ) - - return gemm_rs - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class AllGatherCutlassScaledMMPattern(BasePattern): - def get_inputs(self) -> list[torch.Tensor]: - x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = ( - torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - .contiguous() - .transpose(0, 1) - ) - - s1 = x.shape[0] * self.tp_size - - scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32) - scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - - s2 = weight.shape[1] - output = torch.empty([s1, s2], device=self.device, dtype=self.dtype) - - return [x, weight, scale_a, scale_b, output] - - def register(self, pm_pass: PatternMatcherPass) -> None: - def pattern( - x: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - output: torch.Tensor, - ) -> torch.Tensor: - all_gather = torch.ops.vllm.all_gather.default( - x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name - ) - - cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( - torch.ops._C.cutlass_scaled_mm.default, - out=output, - a=all_gather, - b=weight, - a_scales=scale_a, - b_scales=scale_b, - bias=None, - ) - return cutlass_scaled_mm[1] - - def replacement( - x: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - output: torch.Tensor, - ) -> torch.Tensor: - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa - x, - [weight], - scale_a, - [scale_b], - gather_dim=0, - biases=[None], - result_scales=[None], - out_dtypes=[self.dtype], - use_fast_accum=[False], - group_name=self.tp.device_group.group_name, - ) - return mm_outputs - - pm.register_replacement( - pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass - ) - - -class AsyncTPPass(VllmPatternMatcherPass): - @enable_fake_mode - def __init__(self, config: VllmConfig) -> None: - super().__init__(config) - - # Enable symmetric memory for the TP process group - enable_symm_mem_for_group(get_tp_group().device_group.group_name) - self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="async_tp_pass" - ) - GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) - - AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns) - - # These fusions are enabled only for bfloat16 models because - # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling - # only supports bfloat16 as the output dtype. 
- if self.model_dtype == torch.bfloat16: - ScaledMMReduceScatterPattern(self.model_dtype, self.device).register( - self.patterns - ) - AllGatherScaledMMPattern(self.model_dtype, self.device).register( - self.patterns - ) - - CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register( - self.patterns - ) - AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register( - self.patterns - ) - - self.dump_patterns(config, self.patterns) - - def is_applicable_for_range(self, compile_range: Range) -> bool: - # This pass is applied on top of the sequence parallelism pass. - # It inherits the same applicability condition as `SequenceParallelismPass`. - # See `SequenceParallelismPass.is_applicable` for more details. - if ( - not self.compilation_config.splitting_ops - or self.compilation_config.use_inductor_graph_partition - ): - return True - tp_size = get_tensor_model_parallel_world_size() - return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0) - - @VllmInductorPass.time_and_log - def __call__(self, graph: fx.Graph) -> None: - self.matched_count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", self.matched_count) - - # Max size of the input tensor per world size per device capability # to use flashinfer fused allreduce FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = { @@ -623,6 +224,15 @@ class FlashInferFusedAllReduceParams: } +# TODO(luka): unify +class BasePattern: + def __init__(self, dtype: torch.dtype, device: str | None) -> None: + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + class AllReduceRMSNormPattern(BasePattern): """ This pattern replaces the allreduce + rms norm (without residual) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/passes/fusion/attn_quant_fusion.py similarity index 98% rename from vllm/compilation/fusion_attn.py rename to vllm/compilation/passes/fusion/attn_quant_fusion.py index 0dc4b1489..a104aab6c 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/passes/fusion/attn_quant_fusion.py @@ -22,11 +22,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 -from .fx_utils import is_func -from .inductor_pass import enable_fake_mode +from ..fx_utils import is_func +from ..inductor_pass import enable_fake_mode +from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .matcher_utils import MatcherQuantFP8 -from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass +from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 logger = init_logger(__name__) P = ParamSpec("P") diff --git a/vllm/compilation/passes/fusion/collective_fusion.py b/vllm/compilation/passes/fusion/collective_fusion.py new file mode 100644 index 000000000..55a5a2e5d --- /dev/null +++ b/vllm/compilation/passes/fusion/collective_fusion.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch.distributed._symmetric_memory import enable_symm_mem_for_group + +from vllm.config import VllmConfig +from vllm.config.utils import Range +from vllm.distributed import 
get_tp_group +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from ..inductor_pass import enable_fake_mode +from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass + +FP8_DTYPE = current_platform.fp8_dtype() + +logger = init_logger(__name__) + + +class BasePattern: + def __init__(self, dtype: torch.dtype, device: str | None) -> None: + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + +class GEMMReduceScatterPattern(BasePattern): + def get_inputs(self) -> list[torch.Tensor]: + mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [mul, mm_weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor: + mm = torch.ops.aten.mm.default(mul, mm_weight) + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + return reduce_scatter + + def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor: + gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( + mul, + mm_weight, + "avg", + scatter_dim=0, + group_name=self.tp.device_group.group_name, + ) + + return gemm_rs + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class AllGatherGEMMPattern(BasePattern): + def get_inputs(self) -> list[torch.Tensor]: + x = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + return [x, weight] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + + return torch.ops.aten.mm.default(all_gather, weight) + + def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( + x, + [weight], + gather_dim=0, + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class ScaledMMReduceScatterPattern(BasePattern): + def get_inputs(self) -> list[torch.Tensor]: + input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + mm_weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) + scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) + scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) + return [input, mm_weight, scale_a, scale_b] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: + scaled_mm = torch.ops.aten._scaled_mm.default( + input, + mat2=mat2, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype, + ) + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + scaled_mm, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + return reduce_scatter + + def replacement( + input: torch.Tensor, + mat2: 
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            # Calculate output shape: input @ mat2 with scatter_dim reduced
+            output_shape = [*input.shape[:-1], mat2.shape[1]]
+            scatter_dim = 0
+            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
+                input,
+                mat2,
+                scale_a,
+                scale_b,
+                "avg",
+                scatter_dim,  # orig_scatter_dim
+                scatter_dim,  # scatter_dim_after_maybe_reshape
+                self.tp.device_group.group_name,
+                output_shape,
+                None,  # bias
+                None,  # result_scale
+                self.dtype,  # out_dtype
+                False,  # use_fast_accum
+            )
+
+            return gemm_rs
+
+        pm.register_replacement(
+            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+        )
+
+
+class AllGatherScaledMMPattern(BasePattern):
+    def get_inputs(self) -> list[torch.Tensor]:
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = (
+            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
+            .contiguous()
+            .transpose(0, 1)
+        )
+
+        s1 = x.shape[0] * self.tp_size
+
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        return [x, weight, scale_a, scale_b]
+
+    def register(self, pm_pass: PatternMatcherPass) -> None:
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.all_gather.default(
+                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
+            )
+
+            return torch.ops.aten._scaled_mm.default(
+                all_gather,
+                mat2=weight,
+                scale_a=scale_a,
+                scale_b=scale_b,
+                bias=None,
+                scale_result=None,
+                out_dtype=self.dtype,
+            )
+
+        def replacement(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+        ) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(
+            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+        )
+
+
+class CutlassScaledMMReduceScatterPattern(BasePattern):
+    def get_inputs(self) -> list[torch.Tensor]:
+        input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
+        mm_weight = (
+            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
+            .contiguous()
+            .transpose(0, 1)
+        )
+        scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
+        return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
+
+    def register(self, pm_pass: PatternMatcherPass) -> None:
+        def pattern(
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            cutlass_mm_output: torch.Tensor,
+        ) -> torch.Tensor:
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=cutlass_mm_output,
+                a=input,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None,
+            )
+
+            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
+                cutlass_scaled_mm[1],
+                dim=0,
+                world_size=self.tp_size,
+                group_name=self.tp.unique_name,
+            )
+            return reduce_scatter
+
+        def replacement(
+            input: torch.Tensor,
+            mat2: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            cutlass_mm_output: torch.Tensor,
+        ) -> torch.Tensor:
+            # Calculate output shape: input @ mat2 with scatter_dim reduced
+            output_shape = [*input.shape[:-1], mat2.shape[1]]
+            scatter_dim = 0
+            gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
+                input,
+                mat2,
+                scale_a,
+                scale_b,
+                "avg",
+                scatter_dim,  # orig_scatter_dim
+                scatter_dim,  # scatter_dim_after_maybe_reshape
+                self.tp.device_group.group_name,
+                output_shape,
+                None,  # bias
+                None,  # result_scale
+                self.dtype,  # out_dtype
+                False,  # use_fast_accum
+            )
+
+            return gemm_rs
+
+        pm.register_replacement(
+            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+        )
+
+
+class AllGatherCutlassScaledMMPattern(BasePattern):
+    def get_inputs(self) -> list[torch.Tensor]:
+        x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
+        weight = (
+            torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
+            .contiguous()
+            .transpose(0, 1)
+        )
+
+        s1 = x.shape[0] * self.tp_size
+
+        scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
+        scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
+
+        s2 = weight.shape[1]
+        output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
+
+        return [x, weight, scale_a, scale_b, output]
+
+    def register(self, pm_pass: PatternMatcherPass) -> None:
+        def pattern(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            output: torch.Tensor,
+        ) -> torch.Tensor:
+            all_gather = torch.ops.vllm.all_gather.default(
+                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
+            )
+
+            cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
+                torch.ops._C.cutlass_scaled_mm.default,
+                out=output,
+                a=all_gather,
+                b=weight,
+                a_scales=scale_a,
+                b_scales=scale_b,
+                bias=None,
+            )
+            return cutlass_scaled_mm[1]
+
+        def replacement(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            scale_a: torch.Tensor,
+            scale_b: torch.Tensor,
+            output: torch.Tensor,
+        ) -> torch.Tensor:
+            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul(  # noqa
+                x,
+                [weight],
+                scale_a,
+                [scale_b],
+                gather_dim=0,
+                biases=[None],
+                result_scales=[None],
+                out_dtypes=[self.dtype],
+                use_fast_accum=[False],
+                group_name=self.tp.device_group.group_name,
+            )
+            return mm_outputs
+
+        pm.register_replacement(
+            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
+        )
+
+
+class AsyncTPPass(VllmPatternMatcherPass):
+    @enable_fake_mode
+    def __init__(self, config: VllmConfig) -> None:
+        super().__init__(config)
+
+        # Enable symmetric memory for the TP process group
+        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="async_tp_pass"
+        )
+        GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
+
+        AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
+
+        # These fusions are enabled only for bfloat16 models because
+        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
+        # only supports bfloat16 as the output dtype.
+        if self.model_dtype == torch.bfloat16:
+            ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
+                self.patterns
+            )
+            AllGatherScaledMMPattern(self.model_dtype, self.device).register(
+                self.patterns
+            )
+
+            CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
+                self.patterns
+            )
+            AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
+                self.patterns
+            )
+
+        self.dump_patterns(config, self.patterns)
+
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
+        # This pass is applied on top of the sequence parallelism pass.
+        # It inherits the same applicability condition as `SequenceParallelismPass`.
+        # See `SequenceParallelismPass.is_applicable` for more details.
+        if (
+            not self.compilation_config.splitting_ops
+            or self.compilation_config.use_inductor_graph_partition
+        ):
+            return True
+        tp_size = get_tensor_model_parallel_world_size()
+        return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
+
+    @VllmInductorPass.time_and_log
+    def __call__(self, graph: fx.Graph) -> None:
+        self.matched_count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns", self.matched_count)
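# --- Illustrative note (editor's sketch, not part of the patch) ------------
# The AsyncTP patterns above are only sound because the fused symm_mem kernels
# compute the same result as the unfused "mm + collective" graphs they
# replace, while overlapping the matmul with communication. A single-process
# toy sketch of the GEMM + reduce-scatter identity, with the collective
# simulated by combining per-rank partial products (using the same "avg"
# reduction and scatter_dim=0 as the replacement above):
import torch

tp_size = 2
x_parts = [torch.randn(16, 4) for _ in range(tp_size)]  # one partial input per rank
w = torch.randn(4, 4)

# unfused form: every rank computes a full mm, then reduce-scatter over dim 0
partials = [x @ w for x in x_parts]
reduced = sum(partials) / tp_size        # "avg" reduction across ranks
chunks = reduced.chunk(tp_size, dim=0)   # rank r keeps row-chunk r

# what fused_matmul_reduce_scatter would hand rank 0, conceptually
rank0_out = chunks[0]
assert rank0_out.shape == (16 // tp_size, 4)
# ---------------------------------------------------------------------------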
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py
similarity index 100%
rename from vllm/compilation/matcher_utils.py
rename to vllm/compilation/passes/fusion/matcher_utils.py
diff --git a/vllm/compilation/qk_norm_rope_fusion.py b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
similarity index 97%
rename from vllm/compilation/qk_norm_rope_fusion.py
rename to vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
index 3ddd2b87f..dd1f8245e 100644
--- a/vllm/compilation/qk_norm_rope_fusion.py
+++ b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
@@ -15,10 +15,10 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

-from .fusion import empty_bf16, empty_fp32, empty_i64
-from .inductor_pass import enable_fake_mode
+from ..inductor_pass import enable_fake_mode
+from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
-from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
+from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64

 logger = init_logger(__name__)

diff --git a/vllm/compilation/fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py
similarity index 99%
rename from vllm/compilation/fusion.py
rename to vllm/compilation/passes/fusion/rms_quant_fusion.py
index 667828cc6..eac9fea28 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py
@@ -25,13 +25,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 )
 from vllm.platforms import current_platform

-from .inductor_pass import enable_fake_mode
+from ..inductor_pass import enable_fake_mode
+from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 from .matcher_utils import (
     MatcherFusedAddRMSNorm,
     MatcherQuantFP8,
     MatcherRMSNorm,
 )
-from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)
 FP8_DTYPE = current_platform.fp8_dtype()
diff --git a/vllm/compilation/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py
similarity index 98%
rename from vllm/compilation/rocm_aiter_fusion.py
rename to vllm/compilation/passes/fusion/rocm_aiter_fusion.py
index bfbb2b783..8165c18f0 100644
--- a/vllm/compilation/rocm_aiter_fusion.py
+++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py
@@ -9,7 +9,6 @@ from torch._ops import OpOverload

 import vllm.model_executor.layers.quantization.utils.fp8_utils  # noqa: F401
 from vllm._aiter_ops import rocm_aiter_ops
-from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -19,17 +18,18 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 )
 from vllm.platforms import current_platform

-from .fusion import (
-    FusedRMSQuantKey,
-)
-from .inductor_pass import enable_fake_mode
+from ..activation_quant_fusion import ActivationQuantPattern
+from ..inductor_pass import enable_fake_mode
+from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 from .matcher_utils import (
     MatcherFusedAddRMSNorm,
     MatcherQuantFP8,
     MatcherRMSNorm,
     MatcherSiluAndMul,
 )
-from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
+from .rms_quant_fusion import (
+    FusedRMSQuantKey,
+)

 logger = init_logger(__name__)
 FP8_DTYPE = current_platform.fp8_dtype()
diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py
similarity index 98%
rename from vllm/compilation/sequence_parallelism.py
rename to vllm/compilation/passes/fusion/sequence_parallelism.py
index dda653c5f..5fb932d72 100644
--- a/vllm/compilation/sequence_parallelism.py
+++ b/vllm/compilation/passes/fusion/sequence_parallelism.py
@@ -20,10 +20,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 )
 from vllm.platforms import current_platform

-from .inductor_pass import enable_fake_mode
+from ..inductor_pass import enable_fake_mode
+from ..utility.noop_elimination import NoOpEliminationPass
+from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
-from .noop_elimination import NoOpEliminationPass
-from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

 logger = init_logger(__name__)

diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/passes/fx_utils.py
similarity index 100%
rename from vllm/compilation/fx_utils.py
rename to vllm/compilation/passes/fx_utils.py
diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/passes/inductor_pass.py
similarity index 100%
rename from vllm/compilation/inductor_pass.py
rename to vllm/compilation/passes/inductor_pass.py
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/passes/pass_manager.py
similarity index 89%
rename from vllm/compilation/pass_manager.py
rename to vllm/compilation/passes/pass_manager.py
index e0565ccb2..2fd74fcd4 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/passes/pass_manager.py
@@ -8,38 +8,39 @@ from torch import fx as fx

 from vllm import envs
 from vllm._aiter_ops import rocm_aiter_ops
+from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.system_utils import set_env_var

-from .post_cleanup import PostCleanupPass
 from .vllm_inductor_pass import VllmInductorPass

 if rocm_aiter_ops.is_enabled():
-    from vllm.compilation.rocm_aiter_fusion import (
+    from .fusion.rocm_aiter_fusion import (
         RocmAiterRMSNormQuantFusionPass,
         RocmAiterSiluMulFp8GroupQuantFusionPass,
         RocmAiterTritonAddRMSNormPadFusionPass,
     )

 if current_platform.is_cuda_alike():
-    from .activation_quant_fusion import ActivationQuantFusionPass
-    from .fusion import RMSNormQuantFusionPass
-    from .fusion_attn import AttnFusionPass
-    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
-    from .sequence_parallelism import SequenceParallelismPass
+    from .fusion.act_quant_fusion import ActivationQuantFusionPass
+    from .fusion.attn_quant_fusion import AttnFusionPass
+    from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
+    from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
+    from .fusion.sequence_parallelism import SequenceParallelismPass

 if current_platform.is_cuda():
-    from .collective_fusion import AllReduceFusionPass, AsyncTPPass
+    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
+    from .fusion.collective_fusion import AsyncTPPass


-from .fix_functionalization import FixFunctionalizationPass
 from .inductor_pass import (
     CustomGraphPass,
     InductorPass,
     get_pass_context,
 )
-from .noop_elimination import NoOpEliminationPass
+from .utility.fix_functionalization import FixFunctionalizationPass
+from .utility.noop_elimination import NoOpEliminationPass

 logger = init_logger(__name__)

diff --git a/vllm/compilation/passes/utility/__init__.py b/vllm/compilation/passes/utility/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py
similarity index 99%
rename from vllm/compilation/fix_functionalization.py
rename to vllm/compilation/passes/utility/fix_functionalization.py
index ce37968c9..55126a757 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/passes/utility/fix_functionalization.py
@@ -10,8 +10,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

-from .fx_utils import is_func
-from .vllm_inductor_pass import VllmInductorPass
+from ..fx_utils import is_func
+from ..vllm_inductor_pass import VllmInductorPass

 logger = init_logger(__name__)

diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/passes/utility/noop_elimination.py
similarity index 98%
rename from vllm/compilation/noop_elimination.py
rename to vllm/compilation/passes/utility/noop_elimination.py
index 9af904b45..5f7d47ad6 100644
--- a/vllm/compilation/noop_elimination.py
+++ b/vllm/compilation/passes/utility/noop_elimination.py
@@ -9,8 +9,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true

 from vllm.logger import init_logger

-from .fx_utils import is_func
-from .vllm_inductor_pass import VllmInductorPass
+from ..fx_utils import is_func
+from ..vllm_inductor_pass import VllmInductorPass

 logger = init_logger(__name__)

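# --- Illustrative note (editor's sketch, not part of the patch) ------------
# Out-of-tree code that imported the utility passes from the old flat layout
# has to follow the moves shown in this diff. A hypothetical compatibility
# shim, assuming only these two passes are used by the caller:
try:
    from vllm.compilation.passes.utility.fix_functionalization import (
        FixFunctionalizationPass,
    )
    from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
except ImportError:  # older vLLM with the flat vllm.compilation layout
    from vllm.compilation.fix_functionalization import FixFunctionalizationPass
    from vllm.compilation.noop_elimination import NoOpEliminationPass
# ---------------------------------------------------------------------------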
diff --git a/vllm/compilation/post_cleanup.py b/vllm/compilation/passes/utility/post_cleanup.py
similarity index 91%
rename from vllm/compilation/post_cleanup.py
rename to vllm/compilation/passes/utility/post_cleanup.py
index 551175168..d4ecd4d65 100644
--- a/vllm/compilation/post_cleanup.py
+++ b/vllm/compilation/passes/utility/post_cleanup.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from torch import fx

-from vllm.compilation.vllm_inductor_pass import VllmInductorPass
+from ..vllm_inductor_pass import VllmInductorPass


 class PostCleanupPass(VllmInductorPass):
diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/passes/vllm_inductor_pass.py
similarity index 100%
rename from vllm/compilation/vllm_inductor_pass.py
rename to vllm/compilation/passes/vllm_inductor_pass.py
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 556254a65..2b4ce27a3 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
 from pydantic import Field, TypeAdapter, field_validator

 import vllm.envs as envs
-from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
+from vllm.compilation.passes.inductor_pass import CallableInductorPass, InductorPass
 from vllm.config.utils import (
     Range,
     config,
@@ -170,7 +170,9 @@ class PassConfig:

     @staticmethod
     def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
-        from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
+        from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
+            FI_ALLREDUCE_FUSION_MAX_SIZE_MB,
+        )
         from vllm.platforms import current_platform

         if not current_platform.is_cuda():
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index c3b189e01..d50d2f69c 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -191,7 +191,7 @@ class Platform:
         Get the pass manager class for this platform.
         It will be registered as a custom pass under the current_platform.pass_key.
         """
-        return "vllm.compilation.pass_manager.PostGradPassManager"
+        return "vllm.compilation.passes.pass_manager.PostGradPassManager"

     @classmethod
     def get_compile_backend(cls) -> str:
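# --- Illustrative note (editor's sketch, not part of the patch) ------------
# Platforms that override the pass manager must now return the relocated
# dotted path, mirroring the new default above. A minimal sketch with a
# hypothetical out-of-tree platform subclass:
from vllm.platforms.interface import Platform


class MyPlatform(Platform):  # hypothetical
    @classmethod
    def get_pass_manager_cls(cls) -> str:
        # points at the new passes/ package introduced by this refactor
        return "vllm.compilation.passes.pass_manager.PostGradPassManager"
# ---------------------------------------------------------------------------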