Revert #26113 "[Frontend] CompilationConfig overhaul (#20283): deprecate use_inductor in favor of backend, simplify custom_ops" (#26472)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu
2025-10-09 20:43:55 +08:00
committed by GitHub
parent 92be3f3517
commit 5728da11ea
7 changed files with 63 additions and 126 deletions

View File

@@ -180,11 +180,10 @@ class CompilationConfig:
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
backend: str = "inductor"
backend: str = ""
"""The backend for compilation. It needs to be a string:
- "" (empty string): use the default backend ("inductor" on CUDA-alike
platforms).
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the
@@ -193,11 +192,7 @@ class CompilationConfig:
distributed setting. When the compilation level is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation level is 3, the backend is used for the piecewise compilation
(it sees a part of the graph). The backend can not be custom for compilation
level 3. Furthermore, compilation is only piecewise if splitting ops is set
accordingly and use_inductor_cudagraphs_partition is off. Note that the
default options for splitting ops are sufficient for piecewise compilation.
"""
(it sees a part of the graph)."""
custom_ops: list[str] = field(default_factory=list)
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
to enable all, 'none' to disable all. Also specify a list of custom op
@@ -215,12 +210,8 @@ class CompilationConfig:
compilation."""
# Inductor capture
use_inductor: Optional[bool] = None
"""
Whether to use inductor compilation.
This flag is deprecated and will be removed.
Please use the 'backend' option instead.
use_inductor: bool = True
"""Whether to use inductor compilation:
- False: inductor compilation is not used. graph runs in eager
(custom_ops enabled by default).
@@ -228,11 +219,7 @@ class CompilationConfig:
One graph for symbolic shape and one graph per size in compile_sizes
are compiled using configurations in inductor_compile_config.
This setting is ignored if level<PIECEWISE.
For future compatibility:
If use_inductor is True, backend="inductor" otherwise backend="eager".
"""
This setting is ignored if level<PIECEWISE."""
compile_sizes: Optional[list[Union[int, str]]] = None
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
@@ -538,43 +525,7 @@ class CompilationConfig:
"(where 'op' is the registered op name)"
)
# Currently only eager and inductor backend are supported.
# for piecewise compilation. Custom backends are not suppported for
# piecewise compilation. Update when more backends are supported.
if self.level == CompilationLevel.PIECEWISE and self.backend not in [
"",
"eager",
"inductor",
]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
if self.use_inductor is not None:
logger.warning_once(
"The 'use_inductor' flag is deprecated and will be\
removed in a future release."
"Please use the 'backend' option instead.",
)
self.backend = "inductor" if self.use_inductor else "eager"
if self.backend == "":
self.backend = "inductor"
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
Returns:
The backend for the compilation config.
"""
if self.level is None:
raise ValueError(
"No compilation level is set. This method should only be \
called via vllm config where the level is set if none is \
provided."
)
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -582,15 +533,15 @@ class CompilationConfig:
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
if self.backend not in ["eager", "inductor"]:
raise ValueError(
f"Invalid backend for piecewise compilation: {self.backend}"
)
from vllm.compilation.backends import VllmBackend
@@ -743,7 +694,7 @@ class CompilationConfig:
)
inductor_used = (
self.level == CompilationLevel.PIECEWISE and self.backend == "inductor"
self.level == CompilationLevel.PIECEWISE and self.use_inductor
) or (
self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
)