[Frontend][torch.compile] CompilationConfig Overhaul (#20283): name change compilation level to compilation mode, deprecation compilation level (#26355)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Morrison Turnansky
2025-10-14 22:51:16 -04:00
committed by GitHub
parent e66d787bce
commit 96b9aa5aa0
42 changed files with 270 additions and 248 deletions

View File

@@ -26,12 +26,20 @@ else:
logger = init_logger(__name__)
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
PIECEWISE = 3
class CompilationMode:
"""The compilation approach used for torch.compile-based compilation of the
model."""
NONE = 0
"""No torch.compile compilation is applied, model runs in fully eager pytorch mode.
The model runs as-is."""
STOCK_TORCH_COMPILE = 1
"""The standard `torch.compile` compilation pipeline."""
DYNAMO_TRACE_ONCE = 2
"""Single Dynamo trace through the model, avoiding recompilation."""
VLLM_COMPILE = 3
"""Custom vLLM Inductor-based backend with caching, piecewise compilation,
shape specialization, and custom passes."""
class CUDAGraphMode(enum.Enum):
@@ -134,7 +142,7 @@ class CompilationConfig:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- [`level`][vllm.config.CompilationConfig.level]
- [`mode`][vllm.config.CompilationConfig.mode]
- [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
- [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
- [`backend`][vllm.config.CompilationConfig.backend]
@@ -171,14 +179,26 @@ class CompilationConfig:
# Top-level Compilation control
level: int | None = None
"""The level of compilation:
"""
Level is deprecated and will be removed in the next release,
either 0.12.0 or 0.11.2 whichever is soonest.
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
mode: int | None = None
"""The compilation approach used for torch.compile-based compilation of the
model.
- None: If None, we will select the default compilation level.
For V1 engine this is 3, for V0 engine this is 0.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation."""
- None: If None, we will select the default compilation mode.
For V1 engine this is 3.
- 0: NONE: No torch.compile compilation is applied, model runs in fully
eager pytorch mode. The model runs as-is.
- 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline.
- 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding
recompilation by removing guards.
Requires no dynamic-shape-dependent control-flow.
- 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching,
piecewise compilation, shape specialization, and custom passes."""
debug_dump_path: Path | None = None
"""The path to dump the debug information."""
cache_dir: str = ""
@@ -195,11 +215,11 @@ class CompilationConfig:
backend function.
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation level is 1 or 2, the backend is
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
compilation mode is 3, the backend is used for the piecewise compilation
(it sees a part of the graph). The backend can not be custom for compilation
level 3, i.e. the backend must be either eager or inductor. Furthermore,
mode 3, i.e. the backend must be either eager or inductor. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
splitting ops are sufficient for piecewise compilation.
@@ -214,7 +234,7 @@ class CompilationConfig:
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor and
disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True.
Inductor generates (fused) Triton kernels for disabled custom ops."""
splitting_ops: list[str] | None = None
"""A list of ops to exclude from cudagraphs, used in piecewise compilation.
@@ -249,7 +269,7 @@ class CompilationConfig:
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE.
This setting is ignored if mode<VLLM_COMPILE.
For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
@@ -299,7 +319,7 @@ class CompilationConfig:
Currently, the cudagraph mode is only used for the v1 engine.
Note that the cudagraph logic is generally orthogonal to the
compilation logic. While piecewise cudagraphs require piecewise
compilation (level=PIECEWISE and non-empty splitting_ops), full
compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full
cudagraphs are supported with and without compilation.
Warning: This flag is new and subject to change in addition
@@ -312,7 +332,7 @@ class CompilationConfig:
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
In the vLLM V1 Engine, this flag only applies for
CompilationLevel.PIECEWISE (aka -O3).
CompilationMode.VLLM_COMPILE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
Warning: This flag is deprecated and will be removed in the next major or
@@ -426,7 +446,7 @@ class CompilationConfig:
the final hidden states.
"""
factors: list[Any] = []
factors.append(self.level)
factors.append(self.mode)
factors.append(self.backend)
factors.append(self.custom_ops)
factors.append(self.splitting_ops)
@@ -477,6 +497,17 @@ class CompilationConfig:
return value
def __post_init__(self) -> None:
if self.level is not None:
logger.warning(
"Level is deprecated and will be removed in the next release,"
"either 0.12.0 or 0.11.2 whichever is soonest."
"Use mode instead."
"If both level and mode are given,"
"only mode will be used."
)
if self.mode is None:
self.mode = self.level
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
@@ -574,7 +605,7 @@ class CompilationConfig:
# Currently only eager and inductor backend are supported.
# for piecewise compilation. Custom backends are not suppported for
# piecewise compilation. Update when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [
"",
"eager",
"inductor",
@@ -602,24 +633,27 @@ class CompilationConfig:
Returns:
The backend for the compilation config.
"""
if self.level is None:
if self.mode is None:
raise ValueError(
"No compilation level is set. This method should only be \
"No compilation mode is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
if self.mode == CompilationMode.NONE:
raise ValueError("No compilation mode is set.")
from torch._dynamo.backends.registry import list_backends
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.mode in [
CompilationMode.STOCK_TORCH_COMPILE,
CompilationMode.DYNAMO_TRACE_ONCE,
]:
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
assert self.level == CompilationLevel.PIECEWISE
assert self.mode == CompilationMode.VLLM_COMPILE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
@@ -684,11 +718,11 @@ class CompilationConfig:
self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called only when level is
# CompilationLevel.PIECEWISE
assert self.level == CompilationLevel.PIECEWISE, (
# NOTE: this function needs to be called only when mode is
# CompilationMode.VLLM_COMPILE
assert self.mode == CompilationMode.VLLM_COMPILE, (
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE"
"mode is CompilationMode.VLLM_COMPILE"
)
if self.use_inductor_graph_partition:
@@ -769,12 +803,10 @@ class CompilationConfig:
if not self.use_inductor_graph_partition:
# Dynamo-level FX split case
return self.level == CompilationLevel.PIECEWISE
return self.mode == CompilationMode.VLLM_COMPILE
# Inductor partition case
return (
self.backend == "inductor" and self.level > CompilationLevel.NO_COMPILATION
)
return self.backend == "inductor" and self.mode > CompilationMode.NONE
def custom_op_log_check(self):
"""