[torch.compile] Reorganize vllm/compilation and tests/compile (0/N for vLLM IR) (#33731)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: ProExpertProg <luka.govedic@gmail.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -551,7 +551,7 @@ steps:
|
||||
- label: LoRA Test %N # 20min each
|
||||
timeout_in_minutes: 30
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
@@ -647,7 +647,7 @@ steps:
|
||||
- label: Kernels Attention Test %N # 23min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- csrc/attention/
|
||||
@@ -662,7 +662,7 @@ steps:
|
||||
- label: Kernels Quantization Test %N # 64min
|
||||
timeout_in_minutes: 90
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
@@ -675,7 +675,7 @@ steps:
|
||||
- label: Kernels MoE Test %N # 40min
|
||||
timeout_in_minutes: 60
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/cutlass_w8a8/moe/
|
||||
@@ -753,7 +753,7 @@ steps:
|
||||
- label: Benchmarks # 11min
|
||||
timeout_in_minutes: 20
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
working_dir: "/vllm-workspace/.buildkite"
|
||||
source_file_dependencies:
|
||||
@@ -764,7 +764,7 @@ steps:
|
||||
- label: Benchmarks CLI Test # 7min
|
||||
timeout_in_minutes: 20
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@@ -838,7 +838,7 @@ steps:
|
||||
- label: Basic Models Tests (Extra Initialization) %N
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
@@ -900,7 +900,7 @@ steps:
|
||||
- label: Language Models Tests (Extra Standard) %N
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
@@ -921,7 +921,7 @@ steps:
|
||||
- label: Language Models Tests (Hybrid) %N
|
||||
timeout_in_minutes: 75
|
||||
mirror_hardwares: [amdexperimental]
|
||||
agent_pool: mi325_1
|
||||
agent_pool: mi325_8
|
||||
# grade: Blocking
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
@@ -1190,16 +1190,16 @@ steps:
|
||||
- vllm/model_executor/layers/layernorm.py
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- tests/compile/test_fusion_attn.py
|
||||
- tests/compile/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/passes/test_fusion_attn.py
|
||||
- tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/fullgraph/test_full_graph.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
- pytest -v -s tests/compile/passes/test_fusion_attn.py
|
||||
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
||||
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
|
||||
# # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||
# # Wrap with quotes to escape yaml
|
||||
@@ -1556,15 +1556,15 @@ steps:
|
||||
working_dir: "/vllm-workspace/"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
|
||||
- pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
#- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
||||
# - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
|
||||
# Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
|
||||
# in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
|
||||
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
@@ -519,6 +519,7 @@ steps:
|
||||
# However, find does not normally propagate error codes, so we combine it with xargs
|
||||
# (using -0 for proper path handling)
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
|
||||
- pytest -s -v compile/passes --ignore compile/passes/distributed
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
@@ -1080,14 +1081,14 @@ steps:
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- tests/compile/test_fusion_attn.py
|
||||
- tests/compile/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/fullgraph/test_full_graph.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
||||
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
# # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||
# # Wrap with quotes to escape yaml
|
||||
# - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
|
||||
@@ -1421,8 +1422,8 @@ steps:
|
||||
commands:
|
||||
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
|
||||
# Run sequence parallel tests
|
||||
- pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
|
||||
|
||||
- label: Distributed Tests (H100) # optional
|
||||
gpu: h100
|
||||
@@ -1430,7 +1431,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_async_tp.py
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
@@ -2,7 +2,7 @@ group: Compile
|
||||
depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: Sequence Parallel Tests (2 GPUs)
|
||||
- label: Sequence Parallel Correctness Tests (2 GPUs)
|
||||
timeout_in_minutes: 50
|
||||
working_dir: "/vllm-workspace/"
|
||||
num_devices: 2
|
||||
@@ -11,12 +11,12 @@ steps:
|
||||
- vllm/compilation/
|
||||
- vllm/v1/worker/
|
||||
- vllm/v1/cudagraph_dispatcher.py
|
||||
- tests/distributed/test_sequence_parallel.py
|
||||
- tests/compile/correctness_e2e/test_sequence_parallel.py
|
||||
commands:
|
||||
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
|
||||
- pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
|
||||
|
||||
- label: Sequence Parallel Tests (2xH100)
|
||||
- label: Sequence Parallel Correctness Tests (2xH100)
|
||||
timeout_in_minutes: 50
|
||||
working_dir: "/vllm-workspace/"
|
||||
device: h100
|
||||
@@ -24,24 +24,30 @@ steps:
|
||||
num_devices: 2
|
||||
commands:
|
||||
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
|
||||
- pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||
- pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
|
||||
|
||||
- label: AsyncTP Correctness Tests (2xH100)
|
||||
timeout_in_minutes: 50
|
||||
working_dir: "/vllm-workspace/"
|
||||
device: h100
|
||||
optional: true
|
||||
num_devices: 2
|
||||
commands:
|
||||
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
|
||||
- pytest -v -s tests/compile/correctness_e2e/test_async_tp.py
|
||||
|
||||
- label: Distributed Compile Unit Tests (2xH100)
|
||||
timeout_in_minutes: 40
|
||||
timeout_in_minutes: 20
|
||||
working_dir: "/vllm-workspace/"
|
||||
device: h100
|
||||
num_devices: 2
|
||||
source_file_dependencies:
|
||||
- vllm/compilation/
|
||||
- vllm/model_executor/layers
|
||||
- tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/distributed/test_sequence_parallelism.py
|
||||
- tests/compile/distributed/test_async_tp.py
|
||||
- tests/compile/passes/distributed/
|
||||
commands:
|
||||
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
|
||||
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/distributed/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/distributed/test_async_tp.py
|
||||
- pytest -s -v tests/compile/passes/distributed
|
||||
|
||||
- label: Fusion and Compile Unit Tests (B200)
|
||||
timeout_in_minutes: 20
|
||||
@@ -55,17 +61,17 @@ steps:
|
||||
- vllm/model_executor/layers/attention/attention.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
|
||||
- tests/compile/test_fusion_attn.py
|
||||
- tests/compile/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/passes/test_fusion_attn.py
|
||||
- tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/fullgraph/test_full_graph.py
|
||||
commands:
|
||||
# b200 runners are limited, so we limit the tests to the minimum set only supported on Blackwell
|
||||
- nvidia-smi
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py -k FLASHINFER
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
- pytest -v -s tests/compile/passes/test_fusion_attn.py -k FLASHINFER
|
||||
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
# this runner has 2 GPUs available even though num_devices=2 is not set
|
||||
- pytest -v -s tests/compile/distributed/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||
# TODO(luka) move to H100 once pass tests run on H100
|
||||
- pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
@@ -3,7 +3,7 @@ depends_on:
|
||||
- image-build
|
||||
steps:
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
timeout_in_minutes: 30
|
||||
timeout_in_minutes: 10
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
@@ -17,6 +17,14 @@ steps:
|
||||
# (using -0 for proper path handling)
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -print0 | xargs -0 -n1 -I{} pytest -s -v '{}'"
|
||||
|
||||
- label: PyTorch Compilation Passes Unit Tests
|
||||
timeout_in_minutes: 20
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile/passes
|
||||
commands:
|
||||
- pytest -s -v compile/passes --ignore compile/passes/distributed
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test
|
||||
timeout_in_minutes: 35
|
||||
source_file_dependencies:
|
||||
|
||||
@@ -11,10 +11,10 @@ from torch import fx
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx._utils import lazy_format_graph_code
|
||||
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.compilation.pass_manager import with_pattern_match_debug
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.inductor_pass import InductorPass
|
||||
from vllm.compilation.passes.pass_manager import with_pattern_match_debug
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
|
||||
79
tests/compile/correctness_e2e/test_async_tp.py
Normal file
79
tests/compile/correctness_e2e/test_async_tp.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from tests.utils import (
|
||||
compare_two_settings,
|
||||
create_new_process_for_each_test,
|
||||
)
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
def test_async_tp_pass_correctness(
|
||||
model_id: str,
|
||||
tp_size: int,
|
||||
async_tp_enabled: bool,
|
||||
distributed_backend: str,
|
||||
eager_mode: bool,
|
||||
num_gpus_available: int,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
pp_size = 1
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
|
||||
common_args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"8",
|
||||
]
|
||||
if eager_mode:
|
||||
common_args.append("--enforce-eager")
|
||||
|
||||
compilation_config = {
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [2, 4, 8],
|
||||
"splitting_ops": [],
|
||||
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
|
||||
}
|
||||
|
||||
async_tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
"--compilation_config",
|
||||
json.dumps(compilation_config),
|
||||
]
|
||||
|
||||
tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
]
|
||||
|
||||
compare_two_settings(model_id, async_tp_args, tp_args, method="generate")
|
||||
@@ -21,8 +21,8 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
from ...models.registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import compare_two_settings, create_new_process_for_each_test
|
||||
|
||||
logger = init_logger("test_sequence_parallel")
|
||||
|
||||
@@ -82,19 +82,17 @@ INDUCTOR_GRAPH_PARTITION = [
|
||||
]
|
||||
|
||||
FUSION_LOG_PATTERNS: dict[str, re.Pattern] = {
|
||||
"rms_quant_fusion": re.compile(
|
||||
r"\[(?:compilation/)?fusion.py:\d+] Replaced (\d+) patterns"
|
||||
),
|
||||
"act_quant_fusion": re.compile(
|
||||
r"activation_quant_fusion.py:\d+] Replaced (\d+) patterns"
|
||||
),
|
||||
"rms_quant_fusion": re.compile(r"rms_quant_fusion.py:\d+] Replaced (\d+) patterns"),
|
||||
"act_quant_fusion": re.compile(r"act_quant_fusion.py:\d+] Replaced (\d+) patterns"),
|
||||
"norm_rope_fusion": re.compile(
|
||||
r"qk_norm_rope_fusion.py:\d+] Fused QK Norm\+RoPE on (\d+) sites"
|
||||
),
|
||||
"attn_quant_fusion": re.compile(
|
||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes"
|
||||
r"attn_quant_fusion.py:\d+] Fused quant onto (\d+) attention nodes"
|
||||
),
|
||||
"ar_rms_fusion": re.compile(
|
||||
r"allreduce_rms_fusion.py:\d+] Replaced (\d+) patterns"
|
||||
),
|
||||
"ar_rms_fusion": re.compile(r"collective_fusion.py:\d+] Replaced (\d+) patterns"),
|
||||
"sequence_parallel": re.compile(
|
||||
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns"
|
||||
),
|
||||
|
||||
0
tests/compile/passes/__init__.py
Normal file
0
tests/compile/passes/__init__.py
Normal file
0
tests/compile/passes/distributed/__init__.py
Normal file
0
tests/compile/passes/distributed/__init__.py
Normal file
@@ -1,16 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AsyncTPPass
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import (
|
||||
multi_gpu_test,
|
||||
)
|
||||
from vllm.compilation.passes.fusion.collective_fusion import AsyncTPPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
@@ -29,14 +31,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from ...models.registry import HF_EXAMPLE_MODELS
|
||||
from ...utils import (
|
||||
compare_two_settings,
|
||||
create_new_process_for_each_test,
|
||||
multi_gpu_test,
|
||||
)
|
||||
from ..backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
prompts = [
|
||||
@@ -377,67 +371,3 @@ def async_tp_pass_on_test_model(
|
||||
# In post-nodes, fused_matmul_reduce_scatter or \
|
||||
# fused_all_gather_matmul should exist
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"],
|
||||
)
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("async_tp_enabled", [True])
|
||||
@pytest.mark.parametrize("distributed_backend", ["mp"])
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
def test_async_tp_pass_correctness(
|
||||
model_id: str,
|
||||
tp_size: int,
|
||||
async_tp_enabled: bool,
|
||||
distributed_backend: str,
|
||||
eager_mode: bool,
|
||||
num_gpus_available: int,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
pp_size = 1
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
|
||||
common_args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"8",
|
||||
]
|
||||
if eager_mode:
|
||||
common_args.append("--enforce-eager")
|
||||
|
||||
compilation_config = {
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"compile_sizes": [2, 4, 8],
|
||||
"splitting_ops": [],
|
||||
"pass_config": {"fuse_gemm_comms": async_tp_enabled},
|
||||
}
|
||||
|
||||
async_tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
"--compilation_config",
|
||||
json.dumps(compilation_config),
|
||||
]
|
||||
|
||||
tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
]
|
||||
|
||||
compare_two_settings(model_id, async_tp_args, tp_args, method="generate")
|
||||
@@ -6,11 +6,15 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestFP8Layer, has_module_attribute, multi_gpu_test
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.passes.utility.fix_functionalization import (
|
||||
FixFunctionalizationPass,
|
||||
)
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
@@ -33,9 +37,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
@@ -5,12 +5,14 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestFP8Layer, multi_gpu_test
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from vllm.compilation.passes.fx_utils import find_auto_fn
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CUDAGraphMode,
|
||||
@@ -34,9 +36,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
from ...utils import TestFP8Layer, multi_gpu_test
|
||||
from ..backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@@ -5,12 +5,18 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm.compilation.passes.fusion.act_quant_fusion import (
|
||||
ActivationQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.passes.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.passes.utility.fix_functionalization import (
|
||||
FixFunctionalizationPass,
|
||||
)
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
@@ -26,9 +32,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
from .backend import TestBackend
|
||||
|
||||
TEST_FP8 = current_platform.supports_fp8()
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
@@ -6,9 +6,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.config
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
@@ -19,8 +20,6 @@ from vllm.config import (
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -95,7 +94,7 @@ def test_fuse_act_padding(
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
@@ -7,12 +7,18 @@ import torch
|
||||
|
||||
import vllm.config
|
||||
import vllm.plugins
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestBlockFP8Layer, TestFP8Layer
|
||||
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import (
|
||||
FUSED_OPS,
|
||||
FusedRMSQuantKey,
|
||||
RMSNormQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
@@ -51,9 +57,6 @@ from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
|
||||
from ..utils import TestBlockFP8Layer, TestFP8Layer
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
@@ -223,7 +226,7 @@ class TestModel(torch.nn.Module):
|
||||
if self.use_aiter_fusion:
|
||||
if self.group_shape.is_per_group():
|
||||
# Blockwise aiter fusion
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
)
|
||||
@@ -234,7 +237,7 @@ class TestModel(torch.nn.Module):
|
||||
]
|
||||
else:
|
||||
# Per-token aiter fusion
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
)
|
||||
@@ -410,7 +413,9 @@ def test_aiter_fusion_rmsnorm_quant(
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormQuantFusionPass
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
)
|
||||
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
@@ -6,14 +6,14 @@ import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.utils import flat_product
|
||||
from tests.utils import TestFP8Layer, flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
@@ -38,8 +38,6 @@ from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
@@ -5,11 +5,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
# Important edge case is when `num_tokens == buffer_size`
|
||||
@@ -5,12 +5,12 @@ import copy
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import (
|
||||
from vllm.compilation.passes.inductor_pass import (
|
||||
CallableInductorPass,
|
||||
InductorPass,
|
||||
pass_context,
|
||||
)
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.compilation.passes.pass_manager import PostGradPassManager
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
|
||||
@@ -5,13 +5,17 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.qk_norm_rope_fusion import (
|
||||
from vllm.compilation.passes.fusion.matcher_utils import (
|
||||
FLASHINFER_ROTARY_OP,
|
||||
RMS_OP,
|
||||
ROTARY_OP,
|
||||
)
|
||||
from vllm.compilation.passes.fusion.qk_norm_rope_fusion import (
|
||||
FUSED_QK_ROPE_OP,
|
||||
QKNormRoPEFusionPass,
|
||||
)
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
@@ -6,17 +6,19 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm._aiter_ops import IS_AITER_FOUND
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
from vllm.compilation.passes.fusion.act_quant_fusion import (
|
||||
FUSED_OPS,
|
||||
SILU_MUL_OP,
|
||||
ActivationQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.fusion.rms_quant_fusion import QUANT_OPS
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
@@ -48,9 +50,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
@@ -100,7 +99,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||
super().__init__()
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
from vllm.compilation.passes.fusion.act_quant_fusion import (
|
||||
silu_and_mul_nvfp4_quant_supported,
|
||||
)
|
||||
|
||||
@@ -239,7 +238,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
with set_current_vllm_config(config):
|
||||
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||
if IS_AITER_FOUND:
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from torch import nn
|
||||
import tests.compile.silly_attention # noqa
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.compilation.inductor_pass import (
|
||||
from vllm.compilation.passes.inductor_pass import (
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,9 @@ import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.passes.utility.fix_functionalization import (
|
||||
FixFunctionalizationPass,
|
||||
)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CUDAGraphMode,
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from vllm.compilation.backends import split_graph
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
|
||||
@@ -22,11 +22,6 @@ from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import pass_context
|
||||
from vllm.compilation.partition_rules import (
|
||||
inductor_partition_rule_context,
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.config.utils import Range, hash_factors
|
||||
@@ -44,8 +39,12 @@ from .compiler_interface import (
|
||||
is_compile_cache_enabled,
|
||||
)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .pass_manager import PostGradPassManager
|
||||
from .partition_rules import (
|
||||
inductor_partition_rule_context,
|
||||
should_split,
|
||||
)
|
||||
from .passes.inductor_pass import InductorPass, pass_context
|
||||
from .passes.pass_manager import PostGradPassManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
@@ -24,12 +23,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
flashinfer_comm: ModuleType | None = None
|
||||
if find_spec("flashinfer"):
|
||||
try:
|
||||
@@ -45,406 +46,6 @@ logger = init_logger(__name__)
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [mul, mm_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
mm = torch.ops.aten.mm.default(mul, mm_weight)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
|
||||
return torch.ops.aten.mm.default(all_gather, weight)
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
|
||||
x,
|
||||
[weight],
|
||||
gather_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm = torch.ops.aten._scaled_mm.default(
|
||||
input,
|
||||
mat2=mat2,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
scaled_mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
return torch.ops.aten._scaled_mm.default(
|
||||
all_gather,
|
||||
mat2=weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=cutlass_mm_output,
|
||||
a=input,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
cutlass_scaled_mm[1],
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
s2 = weight.shape[1]
|
||||
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=output,
|
||||
a=all_gather,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
return cutlass_scaled_mm[1]
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Enable symmetric memory for the TP process group
|
||||
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="async_tp_pass"
|
||||
)
|
||||
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
# These fusions are enabled only for bfloat16 models because
|
||||
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
|
||||
# only supports bfloat16 as the output dtype.
|
||||
if self.model_dtype == torch.bfloat16:
|
||||
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
|
||||
@@ -623,6 +224,15 @@ class FlashInferFusedAllReduceParams:
|
||||
}
|
||||
|
||||
|
||||
# TODO(luka): unify
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
@@ -22,11 +22,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .fx_utils import is_func
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..fx_utils import is_func
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
P = ParamSpec("P")
|
||||
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [mul, mm_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
mm = torch.ops.aten.mm.default(mul, mm_weight)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
|
||||
return torch.ops.aten.mm.default(all_gather, weight)
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
|
||||
x,
|
||||
[weight],
|
||||
gather_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm = torch.ops.aten._scaled_mm.default(
|
||||
input,
|
||||
mat2=mat2,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
scaled_mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
return torch.ops.aten._scaled_mm.default(
|
||||
all_gather,
|
||||
mat2=weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=cutlass_mm_output,
|
||||
a=input,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
cutlass_scaled_mm[1],
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
s2 = weight.shape[1]
|
||||
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=output,
|
||||
a=all_gather,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
return cutlass_scaled_mm[1]
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Enable symmetric memory for the TP process group
|
||||
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="async_tp_pass"
|
||||
)
|
||||
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
# These fusions are enabled only for bfloat16 models because
|
||||
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
|
||||
# only supports bfloat16 as the output dtype.
|
||||
if self.model_dtype == torch.bfloat16:
|
||||
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
@@ -15,10 +15,10 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from .fusion import empty_bf16, empty_fp32, empty_i64
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -25,13 +25,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
)
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -9,7 +9,6 @@ from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -19,17 +18,18 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..activation_quant_fusion import ActivationQuantPattern
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .rms_quant_fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -20,10 +20,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -8,38 +8,39 @@ from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
from .fusion.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
from .fusion_attn import AttnFusionPass
|
||||
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .sequence_parallelism import SequenceParallelismPass
|
||||
from .fusion.act_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion.attn_quant_fusion import AttnFusionPass
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .collective_fusion import AllReduceFusionPass, AsyncTPPass
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from .fusion.collective_fusion import AsyncTPPass
|
||||
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from .noop_elimination import NoOpEliminationPass
|
||||
from .utility.fix_functionalization import FixFunctionalizationPass
|
||||
from .utility.noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
0
vllm/compilation/passes/utility/__init__.py
Normal file
0
vllm/compilation/passes/utility/__init__.py
Normal file
@@ -10,8 +10,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -9,8 +9,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .fx_utils import is_func
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
|
||||
@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
from pydantic import Field, TypeAdapter, field_validator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.compilation.passes.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.config.utils import (
|
||||
Range,
|
||||
config,
|
||||
@@ -170,7 +170,9 @@ class PassConfig:
|
||||
|
||||
@staticmethod
|
||||
def default_fi_allreduce_fusion_max_size_mb() -> dict[int, float]:
|
||||
from vllm.compilation.collective_fusion import FI_ALLREDUCE_FUSION_MAX_SIZE_MB
|
||||
from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
|
||||
@@ -191,7 +191,7 @@ class Platform:
|
||||
Get the pass manager class for this platform.
|
||||
It will be registered as a custom pass under the current_platform.pass_key.
|
||||
"""
|
||||
return "vllm.compilation.pass_manager.PostGradPassManager"
|
||||
return "vllm.compilation.passes.pass_manager.PostGradPassManager"
|
||||
|
||||
@classmethod
|
||||
def get_compile_backend(cls) -> str:
|
||||
|
||||
Reference in New Issue
Block a user