[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)

Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
fhl2000
2025-08-15 22:01:39 +08:00
committed by GitHub
parent a0632a3e03
commit 74f441f4b5
34 changed files with 1839 additions and 597 deletions

View File

@@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import hashlib
from collections import Counter
from dataclasses import asdict, field
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union
from pydantic import TypeAdapter
from pydantic import TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
@@ -31,6 +32,40 @@ class CompilationLevel:
PIECEWISE = 3
class CUDAGraphMode(enum.Enum):
""" Constants for the cudagraph mode in CompilationConfig.
Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also
treated as concrete runtime mode for cudagraph runtime dispatching.
"""
NONE = 0
PIECEWISE = 1
FULL = 2
FULL_DECODE_ONLY = (FULL, NONE)
FULL_AND_PIECEWISE = (FULL, PIECEWISE)
def decode_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(self.value[0]) if \
self.separate_routine() else self
def mixed_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(self.value[1]) if \
self.separate_routine() else self
def requires_piecewise_compilation(self) -> bool:
return (self.decode_mode() == CUDAGraphMode.PIECEWISE
or self.mixed_mode() == CUDAGraphMode.PIECEWISE)
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(max(
self.value)) if self.separate_routine() else self
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
@config
@dataclass
class PassConfig:
@@ -91,6 +126,7 @@ class CompilationConfig:
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`]
@@ -157,7 +193,7 @@ class CompilationConfig:
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] = field(default_factory=list)
splitting_ops: Optional[list[str]] = None
"""A list of ops to split the full graph into subgraphs, used in piecewise
compilation."""
@@ -187,7 +223,43 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1)
cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
- FULL_DECODE_ONLY.
- FULL_AND_PIECEWISE.
PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph
incompatiable ops (i.e. some attention ops) outside the cudagraph
for general flexibility.
This is the default mode.
FULL mode: Capture full cudagraph for all batches. Can be good for small
models or workloads with small prompts; not supported by many backends.
Generally for performance FULL_AND_PIECEWISE is better.
FULL_DECODE_ONLY mode: Capture full cudagraph for decode batches only.
Mixed prefill-decode batches are run without cudagraphs. Can be good for
decode instances in a P/D setup where prefill is not as important so we
can save some memory.
FULL_AND_PIECEWISE mode: Capture full cudagraph for decode batches and
piecewise cudagraph for prefill and mixed prefill-decode batches.
This is like the most performant mode for most models.
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change in addition
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
@@ -197,8 +269,9 @@ class CompilationConfig:
CompilationLevel.PIECEWISE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future."""
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
@@ -213,12 +286,17 @@ class CompilationConfig:
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False."""
full_cuda_graph: bool = False
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: Optional[bool] = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
@@ -253,6 +331,13 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
# Attention ops; used for piecewise cudagraphs
_attention_ops: ClassVar[list[str]] = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -297,13 +382,26 @@ class CompilationConfig:
if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude
return TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode()
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
__str__ = __repr__
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
"""
enable parse the `cudagraph_mode` enum type from string
"""
if isinstance(value, str):
return CUDAGraphMode[value.upper()]
return value
def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
@@ -341,7 +439,26 @@ class CompilationConfig:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning("use_cudagraph is deprecated, use "
"cudagraph_mode=NONE instead.")
if self.cudagraph_mode is not None:
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning("full_cuda_graph is deprecated, use "
"cudagraph_mode=FULL instead.")
if self.cudagraph_mode is not None:
raise ValueError("full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -414,15 +531,34 @@ class CompilationConfig:
self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")
# NOTE: this function needs to be called only when level is
# CompilationLevel.PIECEWISE
assert self.level == CompilationLevel.PIECEWISE, (
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE")
if not self.splitting_ops:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture the
# full cudagraph outside the fx graph. This reduces some cpu
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
self.splitting_ops = self._attention_ops
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
"treated as FULL cudagraph_mode. Please ensure you are "
"using attention backends that support cudagraph or set "
"cudagraph_mode to NONE explicitly if encountering "
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)