diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 1d5b9e28a..218b52004 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -5,6 +5,17 @@ This guide covers optimization strategies and performance tuning for vLLM V1. !!! tip Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory. +## Optimization Levels + +vLLM provides 4 optimization levels (`-O0`, `-O1`, `-O2`, `-O3`) that allow users to trade off startup time for performance: + +- `-O0`: No optimizations. Fastest startup time, but lowest performance. +- `-O1`: Fast optimization. Simple compilation and fast fusions, and PIECEWISE cudagraphs. +- `-O2`: Default optimization. Additional compilation ranges, additional fusions, FULL_AND_PIECEWISE cudagraphs. +- `-O3`: Aggressive optimization. Currently equal to `-O2`, but may include additional time-consuming or experimental optimizations in the future. + +For more information, see the [optimization level documentation](../design/optimization_levels.md). + ## Preemption Due to the autoregressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. diff --git a/docs/design/fusions.md b/docs/design/fusions.md new file mode 100644 index 000000000..352c87533 --- /dev/null +++ b/docs/design/fusions.md @@ -0,0 +1,339 @@ +# Fusion torch.compile passes + +vLLM applies a set of kernel/operator fusions at compile time (via custom [`torch.compile`](torch_compile.md) Inductor passes) +to separate optimizations from model definitions and avoid breaking layer abstractions in model code. +These fusions are controlled by fields in [`PassConfig`][vllm.config.compilation.PassConfig] and are automatically enabled +at appropriate [optimization levels](optimization_levels.md). 
+ +## Quick Reference + +The table below maps each fusion to its controlling flag/config knob, the +operations it fuses, what level enables it by default, and an indicative speedup. +The Fullgraph column indicates whether the fusion requires the entire model graph to be +visible (either via Inductor partition or `splitting_ops=[]`), +and the last column indicates whether the fusion activates for all `num_tokens` +or just on the low or high end. + +!!! info + Speedup depends heavily on the exact model, batch size, and hardware. + If tuning performance by hand, always benchmark your exact use-case with and without the fusion to verify the impact. + +| Fusion | `PassConfig` flag | Fused operations | Default at | E2E Speedup | Fullgraph | `num_tokens` | +|--------------------------------------------------------------------------------|------------------------------|------------------------------------------------|--------------------------------|--------------------|-----------|--------------| +| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low | +| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always | +| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O1 (ROCm/AITER only) | TBD | No | Low | +| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low | +| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High | +| [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → 
GEMM | Off by default | 7-10% | Yes | High | +| [RMSNorm + Quant](#rmsnorm--quantization-fuse_norm_quant) | `fuse_norm_quant` | RMSNorm (+residual add) → FP8/FP4 quant | O1 (conditional) | 1-4% | No | Always | +| [SiLU+Mul + Quant](#silumul--quantization-fuse_act_quant) | `fuse_act_quant` | SiLU+Mul activation → FP8/FP4 quant | O1 (conditional) | 1-4% | No | Always | +| [RMSNorm + Padding](#rmsnorm--padding-fuse_act_padding) | `fuse_act_padding` | Residual add + RMSNorm → padding | O1 (ROCm/AITER only) | TBD | No | Always | + +## Support Matrix + +The table below lists the quantization schemes supported by each fusion on each platform. +**—** means the fusion is not available on that platform. The latest and in-progress work is available in the tracking issue: +[#36066](https://github.com/vllm-project/vllm/issues/36066) + +| Fusion | SM100 (Blackwell) | SM90 (Hopper) | SM89 (Ada) | SM80 (Ampere) | ROCm | +|------------------------------|------------------------------------------|------------------------------------------|------------------------------------------|---------------|------------------------------------------| +| `fuse_allreduce_rms` | FP16/BF16, FP8 static, NVFP4 | FP16/BF16, FP8 static | — | — | — | +| `fuse_attn_quant`\* | FP8 static\*, NVFP4\* | FP8 static\* | FP8 static\* | — | FP8 static\* | +| `fuse_rope_kvcache` | — | — | — | — | FP16/BF16 | +| `enable_qk_norm_rope_fusion` | FP16/BF16 | FP16/BF16 | FP16/BF16† | FP16/BF16† | — | +| `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — | +| `fuse_gemm_comms` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — | +| `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group | +| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static | FP8 static | — | FP8 per-group | +| `fuse_act_padding` | — | — | — | — | 
FP16/BF16 |
+
+\* `fuse_attn_quant` support depends on the attention backend in use; not all backends support
+fused quantization output. See the [`fuse_attn_quant` section](#attention--quantization-fuse_attn_quant)
+for per-backend details.
+
+† `enable_sp` and `fuse_gemm_comms` are only autoconfigured for SM90 today;
+support on other architectures requires setting `PassConfig.sp_min_token_num` explicitly.
+SM100 support also requires setting `VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel`.
+
+## Enabling / Disabling Fusions
+
+Fusions are exposed through `PassConfig`, which is nested inside `CompilationConfig`:
+
+```python
+from vllm import LLM
+from vllm.config import CompilationConfig, PassConfig
+
+llm = LLM(
+    model="...",
+    optimization_level=2,  # Default optimization level
+    compilation_config=CompilationConfig(
+        pass_config=PassConfig(
+            fuse_norm_quant=True,
+            fuse_act_quant=True,
+            fuse_allreduce_rms=False,  # disable a specific fusion
+        )
+    ),
+)
+```
+
+Fusions can also be enabled using command-line flags with any `vllm ...` command:
+
+```bash
+# Enable O2 defaults, but turn off allreduce fusion
+vllm serve meta-llama/Llama-3.1-8B-Instruct -O2 -cc.pass_config.fuse_allreduce_rms=False
+
+# The above is equivalent to the more verbose:
+vllm serve meta-llama/Llama-3.1-8B-Instruct -O2 --compilation-config '{"pass_config": {"fuse_allreduce_rms": false}}'
+
+# Same syntax in other commands, e.g. vllm bench:
+vllm bench latency --model=meta-llama/Llama-3.1-8B-Instruct -O2 -cc.pass_config.fuse_allreduce_rms=False
+```
+
+Fields set explicitly by the user always take precedence over optimization-level defaults.
+
+## Fusion Details
+
+### AllReduce + RMSNorm (`fuse_allreduce_rms`)
+
+!!! warning
+    TP+DP and TP+PP combinations are currently broken
+    ([#34458](https://github.com/vllm-project/vllm/issues/34458) and
+    [#35426](https://github.com/vllm-project/vllm/issues/35426)).
+ Only supported on NVIDIA Hopper (SM90) and Blackwell (SM100) with FlashInfer installed. + +**What it fuses.** Fuses the tensor-parallel all-reduce collective with the subsequent residual add, +RMSNorm, and optionally a quantization step into a single FlashInfer / TRT-LLM communication kernel. +This fusion is only profitable for small `num_tokens`, +so the fusion is only performed in the lower compiled range. + +Patterns covered: + +- `AllReduce → RMSNorm(+residual_add)`: CUDA sm90+ with FlashInfer +- `AllReduce → RMSNorm(+residual_add) → FP8 static quant`: CUDA sm90+ with FlashInfer +- `AllReduce → RMSNorm(+residual_add) → NVFP4 dynamic quant`: CUDA sm100+ with FlashInfer + +The maximum tensor size below which the fused kernel is used is hardware-dependent (64 MB for TP=2 +on SM90/SM100) and configurable via `PassConfig.fi_allreduce_fusion_max_size_mb`. + +**Code locations.** + +- Pass: [`vllm/compilation/passes/fusion/allreduce_rms_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/allreduce_rms_fusion.py) +- FlashInfer all-reduce: [`vllm/distributed/device_communicators/flashinfer_all_reduce.py`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/flashinfer_all_reduce.py) +- Benchmark: [`benchmarks/kernels/benchmark_fused_collective.py`](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_fused_collective.py) + +### Attention + Quantization (`fuse_attn_quant`) + +!!! info + `fuse_attn_quant` is currently not enabled at any optimization level by default and must be set + explicitly. It requires the full model graph to be visible (Inductor partition or `splitting_ops=[]`). + +**What it fuses.** Fuses the attention output quantization directly after the attention computation, +eliminating a full-precision memory round-trip of the attention output. 
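To build intuition for the epilogue being fused, here is a pure-Python sketch of per-tensor static FP8-style quantization. It is illustrative only, not vLLM's kernel: 448 is the largest finite FP8 E4M3 value, and integers stand in for the real FP8 bit patterns.

```python
# Illustrative sketch (not vLLM's kernel): static FP8-style quantization of an
# attention output. "Static" means the scale is a fixed, precomputed constant,
# so the quant step can run in the attention epilogue instead of re-reading the
# full-precision output from memory.

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def fp8_static_quant(x: list[float], scale: float) -> list[int]:
    """Quantize with a precomputed per-tensor scale, saturating to the
    FP8 E4M3 range. Integers stand in for the real FP8 bit patterns."""
    out = []
    for v in x:
        q = v / scale
        q = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, q))  # saturate to FP8 range
        out.append(round(q))
    return out


attn_out = [0.5, -1.25, 3.0, 1000.0]  # pretend attention output
print(fp8_static_quant(attn_out, scale=0.25))  # large values saturate at 448
```

The fusion avoids writing `attn_out` to memory in full precision and reading it back just to run this scale-and-clamp step.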
Patterns covered:
+
+`Attention → FP8 static quant`:
+
+- `TRITON_ATTN`: CUDA, ROCm
+- `FLASHINFER`: CUDA sm100+ with FlashInfer installed
+- `ROCM_ATTN`: ROCm
+- `ROCM_AITER_UNIFIED_ATTN`: ROCm with AITER
+
+`Attention → NVFP4 dynamic quant`:
+
+- `FLASHINFER`: CUDA sm100+ with FlashInfer installed
+
+Other attention backends do not support fused output quantization yet.
+
+**Code locations.**
+
+- Pass: [`vllm/compilation/passes/fusion/attn_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/attn_quant_fusion.py)
+- Attention backends: [`vllm/v1/attention/backends/`](https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/)
+
+### RoPE + KV-Cache Update (`fuse_rope_kvcache`)
+
+!!! info
+    ROCm/AITER-only. Not available on NVIDIA CUDA or CPU. The fusion is only enabled for
+    `num_tokens ≤ 256` by default due to AITER fused kernel performance issues.
+    This threshold is configurable via `PassConfig.rope_kvcache_fusion_max_token_num`.
+
+**What it fuses.** Fuses the rotary positional embedding kernel with the KV-cache scatter/write into
+a single kernel, avoiding separate reads and writes of the key and value tensors.
+
+Requires: AMD ROCm with AITER enabled, the `rotary_embedding` custom op active (automatic),
+and the `kv_cache` update op visible in the graph, either by using Inductor graph partition
+or by removing it from `splitting_ops`.
+If these conditions are met, the fusion is enabled automatically at optimization level O1 and above.
+
+**Code locations.**
+
+- Pass: [`vllm/compilation/passes/fusion/rope_kvcache_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rope_kvcache_fusion.py)
+
+### Sequence Parallelism (`enable_sp`)
+
+**What it fuses.** Replaces all-reduce collectives with reduce-scatter + local RMSNorm + all-gather,
+splitting the sequence dimension across TP ranks.
This restructures the graph so the subsequent AsyncTP +pass can fuse the reduce-scatter / all-gather with the surrounding GEMMs. + +Sequence Parallelism itself does not directly improve performance; it is a prerequisite for the +AsyncTP pass (`fuse_gemm_comms`). SP is only applied above a minimum token threshold that is +autoconfigured based on device capability and model `hidden_size`. Currently only active on +H100/SM90 for models with `hidden_size >= 8192`. The threshold is configurable via +`PassConfig.sp_min_token_num`. + +The general transformation: + +```text +Input → AllReduce → RMSNorm → Output +becomes: +Input → ReduceScatter → local RMSNorm → AllGather → Output +``` + +Patterns covered: + +- First block: `AllReduce → RMSNorm` → `ReduceScatter → RMSNorm → AllGather` +- Middle blocks: `AllReduce → fused_add_RMSNorm` → `ReduceScatter → fused_add_RMSNorm → AllGather` +- Both with optional `→ FP8 static quant` suffix + +Requires: `use_inductor_graph_partition=True` **or** piecewise compilation with static sizes +divisible by `tensor_parallel_size`. + +Supported hardware: Only tested on NVIDIA CUDA, possibly works on ROCm. FP8 all-gather requires sm90+. + +**Code locations.** + +- Pass: [`vllm/compilation/passes/fusion/sequence_parallelism.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/sequence_parallelism.py) + +### AsyncTP GEMM + Collective Overlap (`fuse_gemm_comms`) + +!!! info + Requires `enable_sp=True` (enabled automatically). This pass is a no-op if Sequence Parallelism has not been applied. + +**What it fuses.** After Sequence Parallelism transforms the graph, fuses GEMM kernels with the +surrounding reduce-scatter (output projection) and all-gather (input projection) using +`torch.ops.symm_mem` symmetric-memory primitives, overlapping communication and computation. 
+This overlap is only profitable for large `num_tokens`, so the fusion (and preceding SP) +is only performed in the higher compiled range above `PassConfig.sp_min_token_num`. + +Patterns covered: + +- `GEMM → reduce-scatter` → `fused_matmul_reduce_scatter` +- `all-gather → GEMM` → `all_gather_matmul` +- FP8 scaled variants of both patterns + +Supported hardware: NVIDIA CUDA with symmetric-memory (`torch.distributed._symmetric_memory`) support. + +On B200, pattern-matching fp8 FlashInfer scaled MM is not supported, so it must be disabled +([#27893](https://github.com/vllm-project/vllm/issues/27893)) + +```shell +VLLM_DISABLED_KERNELS=FlashInferFP8ScaledMMLinearKernel ... +``` + +**Code locations.** + +- Pass: [`vllm/compilation/passes/fusion/collective_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/collective_fusion.py) +- Sequence parallelism pass: [`vllm/compilation/passes/fusion/sequence_parallelism.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/sequence_parallelism.py) + +### QK Norm + RoPE (`enable_qk_norm_rope_fusion`) + +!!! info + Only applicable to models that apply per-head RMSNorm to Q and K before rotary positional + embedding (e.g. Qwen). Not enabled by default at any optimization level due to perf issues on H100: + [#34391](https://github.com/vllm-project/vllm/issues/34391) + +**What it fuses.** Fuses the sequence: split QKV → reshape → Q/K RMSNorm → reshape → rotary +embedding into a single `fused_qk_norm_rope` CUDA kernel. + +```text +# Unfused: +q, k, v = split(qkv) +q_norm = rms_norm(q.view(heads)) +k_norm = rms_norm(k.view(kv_heads)) +q_rope, k_rope = rotary_embedding(q_norm, k_norm, ...) + +# Fused: +fused_qk_norm_rope(qkv, ...) +``` + +Supported hardware: CUDA (sm80+) only, tested only on sm90 and sm100. 
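The unfused sequence above can be sketched as runnable pure Python. Shapes are simplified to a single head, the interleaved-pair RoPE convention is used, and all names are illustrative rather than vLLM's APIs:

```python
import math

# Pure-Python sketch of the unfused sequence this pass matches:
# per-head RMSNorm on a query/key head, then rotary embedding on the result.
# The fused kernel does both in a single pass over the data.


def rms_norm(head: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Root-mean-square norm over one head's hidden dimension."""
    rms = math.sqrt(sum(v * v for v in head) / len(head) + eps)
    return [v / rms * w for v, w in zip(head, weight)]


def rotary_embedding(head: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Apply RoPE by rotating consecutive pairs (x0, x1), (x2, x3), ..."""
    out = list(head)
    for i in range(len(head) // 2):
        theta = pos * base ** (-2 * i / len(head))
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = head[2 * i], head[2 * i + 1]
        out[2 * i] = x0 * c - x1 * s
        out[2 * i + 1] = x0 * s + x1 * c
    return out


q_head = [1.0, 2.0, 3.0, 4.0]
weight = [1.0] * 4
# Unfused: two passes over the data; the fused kernel makes one.
q_out = rotary_embedding(rms_norm(q_head, weight), pos=5)
```

Note that RoPE is a pure rotation, so it preserves the per-head norm established by the RMSNorm step.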
+ +**Code locations.** + +- Pass: [`vllm/compilation/passes/fusion/qk_norm_rope_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py) +- CUDA kernel: [`csrc/ops.h`](https://github.com/vllm-project/vllm/blob/main/csrc/ops.h) (`fused_qk_norm_rope`) + +### RMSNorm + Quantization (`fuse_norm_quant`) + +!!! warning + On NVIDIA, Inductor actually generates a faster fused kernel than our custom CUDA kernel. + Hence, this fusion is only enabled when either `rms_norm` or `quant_fp8` is using a custom kernel. + +**What it fuses.** Combines the custom `rms_norm` / `fused_add_rms_norm` +operations with subsequent quantization into a single fused kernel, +eliminating an intermediate read/write of the full-precision activation tensor. +Two variants are fused: + +- *Plain RMSNorm + quant*: `rms_norm(x) → quant_fp8(y)` +- *Fused-add RMSNorm + quant*: `fused_add_rms_norm(x, residual) → quant_fp8(y)` — also updates the residual in-place. + +Note that AITER fusions are currently in a separate pass in `vllm.compilation.passes.fusion.rocm_aiter_fusion`. + +Supported quantization scheme/hardware combinations: + +- FP8 static per-tensor: CUDA & HIP kernel +- FP8 dynamic per-token: CUDA & HIP kernel, AITER +- FP8 dynamic per-token-group (128/64): CUDA & HIP kernel, AITER + +**Code locations.** + +- Pass: [`vllm/compilation/passes/fusion/rms_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rms_quant_fusion.py) +- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) +- CUDA/HIP kernels: [`csrc/layernorm_quant_kernels.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_quant_kernels.cu) + +### SiLU+Mul + Quantization (`fuse_act_quant`) + +!!! warning + Same as `fuse_norm_quant`: on NVIDIA, Inductor generates a faster fused kernel than our custom ops. 
    This fusion is only enabled when either `silu_and_mul` or `quant_fp8` is using a custom kernel,
+    or for NVFP4-quantized models (where FP4 quant is always a custom op).
+
+**What it fuses.** Fuses the `silu_and_mul` gate-up projection activation with subsequent quantization into a single kernel,
+avoiding materialization of the full-precision post-activation tensor.
+
+Note that AITER fusions are in a separate pass in `vllm.compilation.passes.fusion.rocm_aiter_fusion`.
+
+Supported quantization scheme/hardware combinations:
+
+- FP8 static per-tensor: CUDA & HIP kernel
+- NVFP4 dynamic: CUDA sm100+ only with FlashInfer
+- FP8 per-token-group (128): ROCm AITER only
+
+**Code locations.**
+
+- Pass: [`vllm/compilation/passes/fusion/act_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/act_quant_fusion.py)
+- ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py)
+- CUDA/HIP kernels: [`csrc/quantization/`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/)
+
+### RMSNorm + Padding (`fuse_act_padding`)
+
+!!! info
+    ROCm/AITER-only. Targeted at GPT-OSS models.
+
+**What it fuses.** Fuses a residual add + RMSNorm with a subsequent padding operation that pads
+the hidden dimension to a multiple required by downstream AITER Triton GEMM kernels.
+
+Requires: AMD ROCm with AITER RMSNorm enabled. Enabled by default at optimization level O1 and above
+when the hidden size is 2880 and AITER Triton GEMMs are *not* enabled.
+
+**Code locations.**
+
+- Pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) (`RocmAiterTritonAddRMSNormPadFusionPass`)
+
+## See Also
+
+- [Optimization Levels](optimization_levels.md) — high-level presets that set
+  fusion defaults.
+- [torch.compile in vLLM](torch_compile.md) — how the Inductor pass pipeline + works. +- [Attention Backends](attention_backends.md) — attention-specific kernel + selection. diff --git a/docs/design/optimization_levels.md b/docs/design/optimization_levels.md index 4987c1820..91af515f4 100644 --- a/docs/design/optimization_levels.md +++ b/docs/design/optimization_levels.md @@ -1,29 +1,18 @@ - - # Optimization Levels ## Overview -vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for users to trade startup time for performance. Higher levels have better performance but worse startup time. These optimization levels have associated defaults to help users get desired out-of-the-box performance. Importantly, defaults set by optimization levels are purely defaults; explicit user settings will not be overwritten. +vLLM provides 4 optimization levels (`-O0`, `-O1`, `-O2`, `-O3`) that allow users to trade off startup time for performance: + +- `-O0`: No optimization. Fastest startup time, but lowest performance. +- `-O1`: Fast optimization. Simple compilation and fast fusions, and PIECEWISE cudagraphs. +- `-O2`: Default optimization. Additional compilation ranges, additional fusions, FULL_AND_PIECEWISE cudagraphs. +- `-O3`: Aggressive optimization. Currently equal to `-O2`, but may include additional time-consuming or experimental optimizations in the future. + +All optimization level defaults can be achieved by manually setting the underlying flags. +User-set flags take precedence over optimization level defaults. 
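Because user-set flags win, a level can be combined with targeted overrides. A sketch using the `-cc.` flag syntax from the level settings (the model name is just a placeholder):

```shell
# Start from the -O2 defaults, but explicitly keep PIECEWISE cudagraphs:
# the user-set flag overrides the level's FULL_AND_PIECEWISE default.
vllm serve meta-llama/Llama-3.1-8B-Instruct -O2 -cc.cudagraph_mode=PIECEWISE
```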
## Level Summaries and Usage Examples -```bash -# CLI usage -python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O0 - -# Python API usage -from vllm.entrypoints.llm import LLM - -llm = LLM( - model="RedHatAI/Llama-3.2-1B-FP8", - optimization_level=0 -) -``` - -#### `-O1`: Quick Optimizations -- **Startup**: Moderate startup time -- **Performance**: Inductor compilation, CUDAGraphMode.PIECEWISE -- **Use case**: Balance for most development scenarios ```bash # CLI usage @@ -34,31 +23,59 @@ from vllm.entrypoints.llm import LLM llm = LLM( model="RedHatAI/Llama-3.2-1B-FP8", - optimization_level=1 + optimization_level=2 # equivalent to -O2 ) ``` -#### `-O2`: Full Optimizations (Default) -- **Startup**: Longer startup time -- **Performance**: `-O1` + CUDAGraphMode.FULL_AND_PIECEWISE -- **Use case**: Production workloads where performance is important. This is the default use case. It is also very similar to the previous default. The primary difference is that noop & fusion flags are enabled. +### `-O0`: No Optimization -```bash -# CLI usage (default, so optional) -python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O2 +Startup as fast as possible - no autotuning, no compilation, and no cudagraphs. +This level is good for initial phases of development and debugging. -# Python API usage -from vllm.entrypoints.llm import LLM +Settings: -llm = LLM( - model="RedHatAI/Llama-3.2-1B-FP8", - optimization_level=2 # This is the default -) -``` +- `-cc.cudagraph_mode=NONE` +- `-cc.mode=NONE` (also resulting in `-cc.custom_ops=["none"]`) +- `-cc.pass_config.fuse_...=False` (all fusions disabled) +- `--kernel-config.enable_flashinfer_autotune=False` -#### `-O3`: Full Optimization -Still in development. Added infrastructure to prevent changing API in future -release. Currently behaves the same O2. +### `-O1`: Fast Optimization + +Prioritize fast startup, but still enable basic optimizations like compilation and cudagraphs. 
+This level is a good balance for most development scenarios where you want faster startup but +still make sure your code does not break cudagraphs or compilation. + +Settings: + +- `-cc.cudagraph_mode=PIECEWISE` +- `-cc.mode=VLLM_COMPILE` +- `--kernel-config.enable_flashinfer_autotune=True` + +Fusions: + +- `-cc.pass_config.fuse_norm_quant=True`* +- `-cc.pass_config.fuse_act_quant=True`* +- `-cc.pass_config.fuse_act_padding=True`† +- `-cc.pass_config.fuse_rope_kvcache=True`† (will be moved to O2) + +\* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.
+† These fusions are ROCm-only and require AITER. + +### `-O2`: Full Optimization (Default) + +Prioritize performance at the expense of additional startup time. +This level is recommended for production workloads and is hence the default. +Fusions in this level _may_ take longer due to additional compile ranges. + +Settings (on top of `-O1`): + +- `-cc.cudagraph_mode=FULL_AND_PIECEWISE` +- `-cc.pass_config.fuse_allreduce_rms=True` + +### `-O3`: Aggressive Optimization + +This level is currently the same as `-O2`, but may include additional optimizations +in the future that are more time-consuming or experimental. ## Troubleshooting @@ -66,4 +83,4 @@ release. Currently behaves the same O2. 1. **Startup Time Too Long**: Use `-O0` or `-O1` for faster startup 2. **Compilation Errors**: Use `debug_dump_path` for additional debugging information -3. **Performance Issues**: Ensure using `-O2` for production \ No newline at end of file +3. **Performance Issues**: Ensure using `-O2` for production diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 59af0109b..bf91fda95 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -4,7 +4,7 @@ import enum from collections import Counter from collections.abc import Callable -from dataclasses import field +from dataclasses import field, fields from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -269,6 +269,24 @@ class PassConfig: ) self.fuse_rope_kvcache = False + def log_enabled_passes(self) -> None: + """ + Log the enabled custom fusion passes. + This is called at the end of VLLMConfig post_init, + after all defaults are finalized. + TODO also log the compile ranges for which this is enabled. 
+ """ + enabled_fusions = [ + f.name[len("fuse_") :] + for f in fields(self) + if getattr(self, f.name) and f.name.startswith("fuse_") + ] + + if enabled_fusions: + logger.info_once( + "Enabled custom fusions: %s", ", ".join(enabled_fusions), scope="global" + ) + class DynamicShapesType(str, enum.Enum): """Types of dynamic shapes handling in torch.compile(). @@ -341,7 +359,8 @@ class CompilationConfig: VLLMConfig's post_init does further initialization. If used outside of the VLLMConfig, some fields will be left in an improper state. - It has three parts: + It contains PassConfig, which controls the custom fusion/transformation passes. + The rest has three parts: - Top-level Compilation control: - [`mode`][vllm.config.CompilationConfig.mode] diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a7c431353..34c668362 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1272,6 +1272,9 @@ class VllmConfig: # Handle the KV connector configs self._post_init_kv_transfer_config() + # Log the custom passes that are enabled + self.compilation_config.pass_config.log_enabled_passes() + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when # enable sequence parallelism