Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
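This commit moves vLLM from yapf (formatting) plus isort (import sorting) to ruff, which covers both jobs: roughly, `ruff format` takes over from yapf, and `ruff check --fix` with ruff's isort-compatible import rules takes over from isort. The hunks below, from the compiler-interface code, are representative of the mechanical changes across all 1508 files: multi-line signatures reflowed to the trailing-comma style, parenthesized imports collapsed onto one line, hanging closing parentheses given their own line, and a blank line inserted after function-level imports.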

@@ -24,14 +24,14 @@ class CompilerInterface:
     """
     The interface for a compiler that can be used by vLLM.
     """
+
     # The name of the compiler, e.g. inductor.
     # This is a class-level attribute.
     name: str
 
-    def initialize_cache(self,
-                         cache_dir: str,
-                         disable_cache: bool = False,
-                         prefix: str = ""):
+    def initialize_cache(
+        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
+    ):
         """
         when the vLLM process uses `cache_dir` as the cache directory,
         the compiler should initialize itself with the cache directory,
@@ -93,12 +93,14 @@ class CompilerInterface:
         """
         return None, None
 
-    def load(self,
-             handle: Any,
-             graph: fx.GraphModule,
-             example_inputs: list[Any],
-             graph_index: int,
-             runtime_shape: Optional[int] = None) -> Callable:
+    def load(
+        self,
+        handle: Any,
+        graph: fx.GraphModule,
+        example_inputs: list[Any],
+        graph_index: int,
+        runtime_shape: Optional[int] = None,
+    ) -> Callable:
         """
         Load the compiled function from the handle.
         Raises an error if the handle is invalid.
@@ -150,11 +152,13 @@ def get_inductor_factors() -> list[Any]:
     factors: list[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
+
     system_factors = CacheBase.get_system()
     factors.append(system_factors)
     # summarize pytorch state
     from torch._inductor.codecache import torch_key
+
     torch_factors = torch_key()
     factors.append(torch_factors)
     return factors
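The factor list assembled here feeds the cache keys used by the adaptors below. A minimal sketch of that fingerprinting pattern, runnable on its own (placeholder factor values, not the real system/torch state):

import hashlib

# Placeholder factors; in vLLM these come from CacheBase.get_system()
# and torch_key() as collected above.
factors = ["example-system-state", "example-torch-key"]

# Same recipe as the compute_hash methods below: an MD5 digest flagged
# as non-security use, truncated to 10 hex chars to serve as a key.
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()[:10]
print(hash_str)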
@@ -169,18 +173,19 @@ class InductorStandaloneAdaptor(CompilerInterface):
     Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
     """
+
     name = "inductor_standalone"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         factors = get_inductor_factors()
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()[:10]
+        hash_str = hashlib.md5(
+            str(factors).encode(), usedforsecurity=False
+        ).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self,
-                         cache_dir: str,
-                         disable_cache: bool = False,
-                         prefix: str = ""):
+    def initialize_cache(
+        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
+    ):
         self.cache_dir = cache_dir
 
     def compile(
@@ -203,12 +208,14 @@ class InductorStandaloneAdaptor(CompilerInterface):
             dynamic_shapes = "from_tracing_context"
 
         from torch._inductor import standalone_compile
+
         with pass_context(runtime_shape):
             compiled_graph = standalone_compile(
                 graph,
                 example_inputs,
                 dynamic_shapes=dynamic_shapes,
-                options={"config_patches": current_config})
+                options={"config_patches": current_config},
+            )
 
         # Save the compiled artifact to disk in the specified path
         assert key is not None
@@ -218,19 +225,23 @@ class InductorStandaloneAdaptor(CompilerInterface):
         compilation_counter.num_compiled_artifacts_saved += 1
         return compiled_graph, (key, path)
 
-    def load(self,
-             handle: Any,
-             graph: fx.GraphModule,
-             example_inputs: list[Any],
-             graph_index: int,
-             runtime_shape: Optional[int] = None) -> Callable:
+    def load(
+        self,
+        handle: Any,
+        graph: fx.GraphModule,
+        example_inputs: list[Any],
+        graph_index: int,
+        runtime_shape: Optional[int] = None,
+    ) -> Callable:
         assert isinstance(handle, tuple)
         assert isinstance(handle[0], str)
         assert isinstance(handle[1], str)
         path = handle[1]
         inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
-            path=path, format="unpacked")
+            path=path, format="unpacked"
+        )
         from torch._inductor.compile_fx import graph_returns_tuple
+
         returns_tuple = graph_returns_tuple(graph)
 
         def compiled_graph_wrapper(*args):
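The `compiled_graph_wrapper` started at the end of this hunk exists because Inductor's compiled artifact always returns a tuple of outputs, while the original `fx.GraphModule` may return a single value; `graph_returns_tuple` tells the wrapper which convention to restore. A minimal sketch of that shape (hypothetical helper, not the exact vLLM code):

from typing import Any, Callable


def make_wrapper(compiled_fn: Callable, returns_tuple: bool) -> Callable:
    # The compiled function returns a tuple; unwrap it when the
    # original graph returned a single value, so callers keep seeing
    # the original calling convention.
    def compiled_graph_wrapper(*args: Any):
        output = compiled_fn(*args)
        return output if returns_tuple else output[0]

    return compiled_graph_wrapper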
@@ -250,21 +261,22 @@ class InductorAdaptor(CompilerInterface):
     """
     The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
     """
+
     name = "inductor"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         factors = get_inductor_factors()
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()[:10]
+        hash_str = hashlib.md5(
+            str(factors).encode(), usedforsecurity=False
+        ).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self,
-                         cache_dir: str,
-                         disable_cache: bool = False,
-                         prefix: str = ""):
+    def initialize_cache(
+        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
+    ):
         self.cache_dir = cache_dir
         self.prefix = prefix
-        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
+        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
         if disable_cache:
             return
         # redirect the cache directory to a sub-directory
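The `base_cache_dir` slice is purely a formatting change (ruff writes the negative bound as `[: -len(prefix)]` where yapf allowed `[:-len(prefix)]`), but a tiny worked example makes the behavior concrete (hypothetical values, not taken from vLLM):

# Hypothetical cache path and per-rank prefix.
cache_dir = "/root/.cache/vllm/rank_0"
prefix = "rank_0"

# Drop len(prefix) characters from the end to recover the base dir.
base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
assert base_cache_dir == "/root/.cache/vllm/"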
@@ -288,6 +300,7 @@ class InductorAdaptor(CompilerInterface):
     ) -> tuple[Optional[Callable], Optional[Any]]:
         compilation_counter.num_inductor_compiles += 1
         from torch._inductor.compile_fx import compile_fx
+
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -308,8 +321,8 @@ class InductorAdaptor(CompilerInterface):
         # it to get the hash of the compiled graph directly.
         hash_str, file_path = None, None
 
-        from torch._inductor.codecache import (FxGraphCache,
-                                               compiled_fx_graph_hash)
+        from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash
+
         if torch.__version__.startswith("2.5"):
             original_load = FxGraphCache.load
             original_load_name = "torch._inductor.codecache.FxGraphCache.load"
@@ -326,7 +339,8 @@ class InductorAdaptor(CompilerInterface):
                         if not callable(cell.cell_contents):
                             continue
                         if cell.cell_contents.__code__.co_filename.startswith(
-                                self.base_cache_dir):
+                            self.base_cache_dir
+                        ):
                             # this is the real file path compiled from Inductor
                             file_path = cell.cell_contents.__code__.co_filename
                             break
@@ -338,8 +352,7 @@ class InductorAdaptor(CompilerInterface):
             original_load_name = None
 
         def hijacked_compile_fx_inner(*args, **kwargs):
-            output = torch._inductor.compile_fx.compile_fx_inner(
-                *args, **kwargs)
+            output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
             nonlocal hash_str
             inductor_compiled_graph = output
             if inductor_compiled_graph is not None:
@@ -353,8 +366,7 @@ class InductorAdaptor(CompilerInterface):
                         if not callable(cell.cell_contents):
                             continue
                         code = cell.cell_contents.__code__
-                        if code.co_filename.startswith(
-                                self.base_cache_dir):
+                        if code.co_filename.startswith(self.base_cache_dir):
                             # this is the real file path
                             # compiled from Inductor
                             file_path = code.co_filename
@@ -387,29 +399,38 @@ class InductorAdaptor(CompilerInterface):
             # for hijacking the hash of the compiled graph
             stack.enter_context(
-                patch("torch._inductor.codecache.compiled_fx_graph_hash",
-                      hijack_compiled_fx_graph_hash))
+                patch(
+                    "torch._inductor.codecache.compiled_fx_graph_hash",
+                    hijack_compiled_fx_graph_hash,
+                )
+            )
 
             # for providing a dummy shape environment
             stack.enter_context(
-                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                      _get_shape_env))
+                patch(
+                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
+                    _get_shape_env,
+                )
+            )
 
-            from torch._functorch._aot_autograd.autograd_cache import (
-                AOTAutogradCache)
+            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 
             # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
             if hasattr(AOTAutogradCache, "_get_shape_env"):
                 stack.enter_context(
                     patch(
                         "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
-                        _get_shape_env))
+                        _get_shape_env,
+                    )
+                )
 
             # for forcing the graph to be cached
             stack.enter_context(
                 patch(
                     "torch._inductor.codecache.FxGraphCache._check_can_cache",
                     _check_can_cache,
-                    _check_can_cache))
+                )
+            )
 
             # Dynamo metrics context, see method for more details.
             stack.enter_context(self.metrics_context())
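All of these monkey-patches are entered on a single `ExitStack`, so they are applied together and unwound together (even on error) when the block exits. A self-contained sketch of the idiom using only the standard library:

import os
from contextlib import ExitStack
from unittest.mock import patch

# Enter any number of patches on one ExitStack; all of them are
# reverted in reverse order when the with-block exits.
with ExitStack() as stack:
    stack.enter_context(patch("os.sep", "/"))
    stack.enter_context(patch("os.linesep", "\n"))
    assert os.sep == "/" and os.linesep == "\n"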
@@ -422,23 +443,26 @@ class InductorAdaptor(CompilerInterface):
             # standalone_compile sometime.
             if is_torch_equal_or_newer("2.6"):
                 stack.enter_context(
-                    torch._inductor.config.patch(fx_graph_remote_cache=False))
+                    torch._inductor.config.patch(fx_graph_remote_cache=False)
+                )
 
             # InductorAdaptor (unfortunately) requires AOTAutogradCache
             # to be turned off to run. It will fail to acquire the hash_str
             # and error if not.
             # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
             stack.enter_context(
-                torch._functorch.config.patch(enable_autograd_cache=False))
+                torch._functorch.config.patch(enable_autograd_cache=False)
+            )
             stack.enter_context(
-                torch._functorch.config.patch(
-                    enable_remote_autograd_cache=False))
+                torch._functorch.config.patch(enable_remote_autograd_cache=False)
+            )
 
             with pass_context(runtime_shape):
                 compiled_graph = compile_fx(
                     graph,
                     example_inputs,
                     inner_compile=hijacked_compile_fx_inner,
-                    config_patches=current_config)
+                    config_patches=current_config,
+                )
 
         # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
         # compilation cache. So turn off the checks if we disable the
@@ -451,52 +475,63 @@ class InductorAdaptor(CompilerInterface):
                     "failed, leading to a corrupted compilation artifact. "
                     "We recommend trying to "
                     "remove ~/.cache/vllm/torch_compile_cache and try again "
-                    "to see the real issue. ")
+                    "to see the real issue. "
+                )
                 assert file_path is not None, (
-                    "failed to get the file path of the compiled graph")
+                    "failed to get the file path of the compiled graph"
+                )
             return compiled_graph, (hash_str, file_path)
 
-    def load(self,
-             handle: Any,
-             graph: fx.GraphModule,
-             example_inputs: list[Any],
-             graph_index: int,
-             runtime_shape: Optional[int] = None) -> Callable:
+    def load(
+        self,
+        handle: Any,
+        graph: fx.GraphModule,
+        example_inputs: list[Any],
+        graph_index: int,
+        runtime_shape: Optional[int] = None,
+    ) -> Callable:
         assert isinstance(handle, tuple)
         assert isinstance(handle[0], str)
         assert isinstance(handle[1], str)
         hash_str = handle[0]
 
-        from torch._functorch._aot_autograd.autograd_cache import (
-            AOTAutogradCache)
+        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
         from torch._inductor.codecache import FxGraphCache
+
         with ExitStack() as exit_stack:
             exit_stack.enter_context(
-                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
-                      lambda *args, **kwargs: AlwaysHitShapeEnv()))
+                patch(
+                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
+                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
+                )
+            )
 
             # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
             if hasattr(AOTAutogradCache, "_get_shape_env"):
                 exit_stack.enter_context(
                     patch(
                         "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
-                        lambda *args, **kwargs: AlwaysHitShapeEnv()))
+                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
+                    )
+                )
 
             # Dynamo metrics context, see method for more details.
             exit_stack.enter_context(self.metrics_context())
 
             if torch.__version__.startswith("2.5"):
                 inductor_compiled_graph = FxGraphCache._lookup_graph(
-                    hash_str, example_inputs, True, False)
+                    hash_str, example_inputs, True, False
+                )
                 assert inductor_compiled_graph is not None, (
                     "Inductor cache lookup failed. Please remove"
                     f"the cache directory and try again."  # noqa
                 )
             elif torch.__version__ >= "2.6":
-                from torch._inductor.output_code import (
-                    CompiledFxGraphConstantsWithGm)
+                from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
+
                 constants = CompiledFxGraphConstantsWithGm(graph)
                 inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
-                    hash_str, example_inputs, True, None, constants)
+                    hash_str, example_inputs, True, None, constants
+                )
                 assert inductor_compiled_graph is not None, (
                     "Inductor cache lookup failed. Please remove"
                     f"the cache directory and try again."  # noqa
@@ -509,6 +544,7 @@ class InductorAdaptor(CompilerInterface):
         # need to know if the graph returns a tuple
         from torch._inductor.compile_fx import graph_returns_tuple
+
        returns_tuple = graph_returns_tuple(graph)
 
         # this is the callable we return to Dynamo to run
@@ -542,6 +578,7 @@ class InductorAdaptor(CompilerInterface):
         """
         if is_torch_equal_or_newer("2.6"):
             import torch._dynamo.utils
+
             return torch._dynamo.utils.get_metrics_context()
         else:
             return contextlib.nullcontext()
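The `else` branch above is the standard fallback idiom: when the real context manager is unavailable, return `contextlib.nullcontext()` so callers can write an unconditional `with` statement. A standalone sketch (hypothetical flag and helper, standard library only):

import contextlib


def maybe_context(feature_available: bool):
    # Return a real context manager when the feature exists, otherwise
    # a no-op placeholder with the same interface.
    if feature_available:
        return contextlib.suppress(ZeroDivisionError)
    return contextlib.nullcontext()


with maybe_context(feature_available=False):
    pass  # callers need no version checks either way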
@@ -553,7 +590,8 @@ def set_inductor_config(config, runtime_shape):
         # can be beneficial
         config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
         config["coordinate_descent_tuning"] = (
-            envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
+            envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
+        )
 
 
 class EagerAdaptor(CompilerInterface):