[Misc][BE] Type coverage for vllm/compilation [2/3] (#31744)
This commit is contained in:
@@ -179,7 +179,7 @@ class CompilerManager:
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable | None:
|
||||
) -> Callable[..., Any] | None:
|
||||
if (compile_range, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
|
||||
@@ -199,7 +199,7 @@ class CompilerManager:
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
additional_inductor_config,
|
||||
additional_inductor_config: dict[str, Any],
|
||||
compilation_config: CompilationConfig,
|
||||
compile_range: Range,
|
||||
graph_index: int = 0,
|
||||
@@ -355,7 +355,7 @@ def split_graph(
|
||||
compilation_start_time = 0.0
|
||||
|
||||
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
|
||||
It runs the given graph with fake inputs, and compiles some
|
||||
submodules specified by `compile_submod_names` with the given
|
||||
@@ -506,9 +506,9 @@ class VllmBackend:
|
||||
# the stitching graph module for all the piecewise graphs
|
||||
split_gm: fx.GraphModule
|
||||
piecewise_graphs: list[SplitItem]
|
||||
returned_callable: Callable
|
||||
returned_callable: Callable[..., Any]
|
||||
# Inductor passes to run on the graph pre-defunctionalization
|
||||
post_grad_passes: Sequence[Callable]
|
||||
post_grad_passes: Sequence[Callable[..., Any]]
|
||||
sym_tensor_indices: list[int]
|
||||
input_buffers: list[torch.Tensor]
|
||||
compiler_manager: CompilerManager
|
||||
@@ -821,7 +821,7 @@ class VllmBackend:
|
||||
]
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def copy_and_call(*args):
|
||||
def copy_and_call(*args: Any) -> Any:
|
||||
list_args = list(args)
|
||||
for i, index in enumerate(self.sym_tensor_indices):
|
||||
runtime_tensor = list_args[index]
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type)
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VllmSerializableFunction(SerializableCallable):
|
||||
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
"""
|
||||
A wrapper around a compiled function by vllm. It will forward the tensor
|
||||
inputs to the compiled function and return the result.
|
||||
@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
|
||||
):
|
||||
self,
|
||||
graph_module: torch.fx.GraphModule,
|
||||
example_inputs: Sequence[Any],
|
||||
prefix: str,
|
||||
optimized_call: Callable[..., Any],
|
||||
is_encoder: bool = False,
|
||||
) -> None:
|
||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||
self.graph_module = graph_module
|
||||
self.example_inputs = example_inputs
|
||||
@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable):
|
||||
if sym_input is not None:
|
||||
self.shape_env = sym_input.node.shape_env
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.optimized_call(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable):
|
||||
|
||||
graph_reducer_override = GraphPickler.reducer_override
|
||||
|
||||
def _graph_reducer_override(self, obj):
|
||||
def _graph_reducer_override(
|
||||
self: GraphPickler, obj: Any
|
||||
) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, sympy.Function)
|
||||
@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable):
|
||||
get_current_vllm_config(), state["prefix"], is_encoder
|
||||
)
|
||||
|
||||
def optimized_call(*example_inputs):
|
||||
def optimized_call(*example_inputs: Any) -> Any:
|
||||
"""
|
||||
On the first run of the optimized call, we rerun the compiler
|
||||
backend which should result in a cache hit. After the backend
|
||||
@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable):
|
||||
return fn
|
||||
|
||||
@property
|
||||
def co_name(self):
|
||||
def co_name(self) -> Literal["VllmSerializableFunction"]:
|
||||
"""
|
||||
Used for depyf debugging.
|
||||
"""
|
||||
|
||||
@@ -42,7 +42,9 @@ class CUDAGraphLogging:
|
||||
"Count",
|
||||
]
|
||||
|
||||
def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
|
||||
def __init__(
|
||||
self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
|
||||
) -> None:
|
||||
self.reset()
|
||||
self.cg_mode = str(cg_mode)
|
||||
self.cg_capture_sizes = str(cg_capture_sizes or [])
|
||||
@@ -54,10 +56,10 @@ class CUDAGraphLogging:
|
||||
"**CUDAGraph Stats:**\n\n"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self.stats = []
|
||||
def reset(self) -> None:
|
||||
self.stats: list[CUDAGraphStat] = []
|
||||
|
||||
def observe(self, cudagraph_stat: CUDAGraphStat):
|
||||
def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
|
||||
self.stats.append(cudagraph_stat)
|
||||
|
||||
def generate_metric_table(self) -> str:
|
||||
@@ -109,7 +111,7 @@ class CUDAGraphLogging:
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
def log(self, log_fn=logger.info):
|
||||
def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
|
||||
if not self.stats:
|
||||
return
|
||||
log_fn(self.generate_metric_table())
|
||||
@@ -161,11 +163,11 @@ class CUDAGraphWrapper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable,
|
||||
runnable: Callable[..., Any],
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: CUDAGraphOptions | None = None,
|
||||
):
|
||||
) -> None:
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
@@ -189,7 +191,7 @@ class CUDAGraphWrapper:
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
@@ -198,11 +200,11 @@ class CUDAGraphWrapper:
|
||||
f"cudagraph wrapper: {self.runnable}"
|
||||
)
|
||||
|
||||
def unwrap(self) -> Callable:
|
||||
def unwrap(self) -> Callable[..., Any]:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
@@ -6,8 +6,8 @@ import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar, overload
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Only added on nightly/2.10 so wrap
|
||||
try:
|
||||
from torch._dynamo.package import SourceInfo
|
||||
except ImportError:
|
||||
# Fallback for old versions that do not provide SourceInfo
|
||||
SourceInfo = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls) -> bool:
|
||||
def _should_ignore_torch_compile(cls: _T) -> bool:
|
||||
"""
|
||||
Check if the class should be ignored for torch.compile.
|
||||
"""
|
||||
@@ -224,7 +232,7 @@ def support_torch_compile(
|
||||
return cls_decorator_helper
|
||||
|
||||
|
||||
def _model_hash_key(fn) -> str:
|
||||
def _model_hash_key(fn: Callable[..., Any]) -> str:
|
||||
import vllm
|
||||
|
||||
sha256_hash = hashlib.sha256()
|
||||
@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def _verify_source_unchanged(source_info, vllm_config) -> None:
|
||||
def _verify_source_unchanged(
|
||||
source_info: "SourceInfo", vllm_config: VllmConfig
|
||||
) -> None:
|
||||
from .caching import _compute_code_hash, _compute_code_hash_with_content
|
||||
|
||||
file_contents = {}
|
||||
@@ -275,8 +285,12 @@ def _support_torch_compile(
|
||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
|
||||
):
|
||||
self: _T,
|
||||
*,
|
||||
vllm_config: VllmConfig | None = None,
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if vllm_config is None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
@@ -309,13 +323,17 @@ def _support_torch_compile(
|
||||
|
||||
compilation_counter.num_models_seen += 1
|
||||
self.compiled = False
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type]
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
|
||||
def mark_dynamic(arg, dims):
|
||||
if type == DynamicShapesType.UNBACKED:
|
||||
def _mark_dynamic_inputs(
|
||||
mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
if is_torch_equal_or_newer("2.10.0.dev"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
@@ -326,7 +344,7 @@ def _support_torch_compile(
|
||||
else:
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
|
||||
sig = inspect.signature(mod.__class__.forward)
|
||||
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
|
||||
bound_args = sig.bind(mod, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
@@ -364,7 +382,7 @@ def _support_torch_compile(
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
@@ -444,7 +462,7 @@ def _support_torch_compile(
|
||||
not envs.VLLM_USE_AOT_COMPILE
|
||||
or self.vllm_config.compilation_config.backend == "eager"
|
||||
)
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
# This is the path for the first compilation.
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
@@ -477,7 +495,7 @@ def _support_torch_compile(
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call_
|
||||
|
||||
def patched_inline_call(self_):
|
||||
def patched_inline_call(self_: Any) -> Any:
|
||||
code = self_.f_code
|
||||
self.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
@@ -535,7 +553,7 @@ def _support_torch_compile(
|
||||
str(e),
|
||||
)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.compiled = True
|
||||
return output
|
||||
@@ -545,7 +563,9 @@ def _support_torch_compile(
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
def maybe_use_cudagraph_partition_wrapper(
|
||||
vllm_config: VllmConfig,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Context manager to set/unset customized cudagraph partition wrappers.
|
||||
|
||||
@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
|
||||
def customized_cudagraph_wrapper(
|
||||
f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
|
||||
) -> Any:
|
||||
partition_id = metadata.partition_index
|
||||
num_partitions = metadata.num_partitions
|
||||
return static_graph_wrapper_class(
|
||||
@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _torch27_patch_tensor_subclasses():
|
||||
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
|
||||
"""
|
||||
Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when
|
||||
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||
@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
|
||||
_ColumnvLLMParameter,
|
||||
)
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
|
||||
|
||||
@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]):
|
||||
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, torch.fx.Node | str],
|
||||
args: tuple[torch.fx.Node | str, ...] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
|
||||
def replace_users_with_mutated_args(
|
||||
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: tuple[torch.fx.Node | str, ...] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Insert a new defunctionalized node into the graph before node.
|
||||
If one of the kwargs is 'out', provide args directly,
|
||||
|
||||
@@ -29,6 +29,9 @@ else:
|
||||
Torch25CustomGraphPass as CustomGraphPass,
|
||||
)
|
||||
|
||||
# Re-export CustomGraphPass for external usage
|
||||
__all__ = ["CustomGraphPass"]
|
||||
|
||||
_pass_context = None
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
count = 0
|
||||
# Remove no-op reshapes/views:
|
||||
for node in graph.nodes:
|
||||
@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
2. The dimensions both correspond to the same SymInt
|
||||
"""
|
||||
# Case 1
|
||||
return statically_known_true(dim == i_dim)
|
||||
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
|
||||
|
||||
def all_dims_equivalent(
|
||||
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def with_pattern_match_debug(fn):
|
||||
|
||||
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Function decorator that turns on inductor pattern match debug
|
||||
for the duration of the call.
|
||||
@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn):
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
|
||||
# optionally check rank here
|
||||
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
|
||||
@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):
|
||||
class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
It handles configuration, adding custom passes, and running passes.
|
||||
@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass):
|
||||
This way, all passes operate on a functionalized graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.passes: list[InductorPass] = []
|
||||
|
||||
@with_pattern_match_debug
|
||||
def __call__(self, graph: fx.Graph):
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
VllmInductorPass.dump_prefix = 0 # reset dump index
|
||||
|
||||
compile_range = get_pass_context().compile_range
|
||||
@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass):
|
||||
self.fix_functionalization(graph)
|
||||
VllmInductorPass.dump_prefix = None # Cleanup index
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
def configure(self, config: VllmConfig) -> None:
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
|
||||
# Set the current vllm config to allow tracing CustomOp instances
|
||||
@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass):
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass):
|
||||
def add(self, pass_: InductorPass) -> None:
|
||||
assert isinstance(pass_, InductorPass)
|
||||
self.passes.append(pass_)
|
||||
|
||||
def uuid(self):
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
The PostGradPassManager is set as a custom pass in the Inductor and
|
||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||
dependent passes and the pass config. See InductorPass for more info.
|
||||
"""
|
||||
state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
|
||||
passes = []
|
||||
|
||||
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
|
||||
for pass_ in self.passes:
|
||||
state["passes"].append(pass_.uuid())
|
||||
state["passes"].append(self.fix_functionalization.uuid())
|
||||
passes.append(pass_.uuid())
|
||||
passes.append(self.fix_functionalization.uuid())
|
||||
|
||||
# Include the compile range in the uuid to ensure that inductor
|
||||
# recompiles the graph for the new dynamic compile range.
|
||||
state["compile_range"] = str(get_pass_context().compile_range)
|
||||
|
||||
state["passes"] = passes
|
||||
return InductorPass.hash_dict(state)
|
||||
|
||||
@@ -86,27 +86,36 @@ class PiecewiseBackend:
|
||||
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
for size in self.compile_sizes:
|
||||
range = Range(start=size, end=size)
|
||||
if range not in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
self.to_be_compiled_ranges.add(range)
|
||||
if self.compile_sizes is not None:
|
||||
for size in self.compile_sizes:
|
||||
if isinstance(size, str):
|
||||
assert size == "cudagraph_capture_sizes"
|
||||
raise NotImplementedError(
|
||||
"cudagraph_capture_sizes not supported in compile_sizes."
|
||||
"This should be handled in `post_init_cudagraph_sizes`."
|
||||
)
|
||||
else:
|
||||
assert isinstance(size, int)
|
||||
range = Range(start=size, end=size)
|
||||
if range not in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
self.to_be_compiled_ranges.add(range)
|
||||
|
||||
for range in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
def check_for_ending_compilation(self) -> None:
|
||||
if self.is_last_graph and not self.to_be_compiled_ranges:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def _fakify_args(self, args: list[Any]) -> list[Any]:
|
||||
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
|
||||
# We need to pass fake example_inputs, otherwise torch.compile
|
||||
# will fakify the example_inputs potentially causing some non dynamic
|
||||
# dimension to be duck shaped to other existing shapes that have hints
|
||||
@@ -127,7 +136,9 @@ class PiecewiseBackend:
|
||||
assert len(fake_example_inputs) == len(args)
|
||||
return fake_example_inputs
|
||||
|
||||
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
|
||||
def _maybe_compile_for_range_entry(
|
||||
self, range_entry: RangeEntry, args: tuple[Any, ...]
|
||||
) -> Any:
|
||||
if not range_entry.compiled:
|
||||
range_entry.compiled = True
|
||||
self.to_be_compiled_ranges.remove(range_entry.compile_range)
|
||||
@@ -136,14 +147,14 @@ class PiecewiseBackend:
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args = (
|
||||
args_list = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else args
|
||||
else list(args)
|
||||
)
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
args_list,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
@@ -153,10 +164,13 @@ class PiecewiseBackend:
|
||||
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
|
||||
# First we try to find the range entry for the concrete compile size
|
||||
# If not found, we search for the range entry
|
||||
# that contains the runtime shape.
|
||||
if self.compile_sizes is None:
|
||||
return None
|
||||
|
||||
if runtime_shape in self.compile_sizes:
|
||||
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
|
||||
else:
|
||||
@@ -165,7 +179,7 @@ class PiecewiseBackend:
|
||||
return self.range_entries[range]
|
||||
return None
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
def __call__(self, *args: Any) -> Any:
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
range_entry = self._find_range_for_shape(runtime_shape)
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
def _noop_add_global_state_guard(self, *args, **kwargs):
|
||||
|
||||
def _noop_add_global_state_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
|
||||
def _noop_add_torch_function_mode_stack_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _compilation_context():
|
||||
def _compilation_context() -> Generator[None, None, None]:
|
||||
"""Context manager for compilation settings and patches.
|
||||
|
||||
This manager:
|
||||
@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
since we drop all guards.
|
||||
"""
|
||||
|
||||
def check_invariants_and_forward(self, *args, **kwargs):
|
||||
def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
|
||||
assert hasattr(self, "_check_shape_invariants")
|
||||
self._check_shape_invariants(*args, **kwargs)
|
||||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
|
||||
def _call_with_optional_nvtx_range(
|
||||
self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> Any:
|
||||
if self.layerwise_nvtx_tracing_enabled:
|
||||
args_list = list(args)
|
||||
kwargs_dict = dict(kwargs)
|
||||
@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
return ctx.result
|
||||
return callable_fn(*args, **kwargs)
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.compiled = False
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
self._compiled_bytecode = None
|
||||
self._compiled_bytecode: CodeType | None = None
|
||||
|
||||
def aot_compile(self, *args, **kwargs):
|
||||
def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if not hasattr(self._compiled_callable, "aot_compile"):
|
||||
raise RuntimeError(
|
||||
"aot_compile is not supported by the current configuration. "
|
||||
@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
)
|
||||
return self._compiled_callable.aot_compile((args, kwargs))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
if (
|
||||
self.vllm_config.compilation_config.mode
|
||||
@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs): ...
|
||||
def forward(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def original_code_object(self) -> CodeType:
|
||||
"""Return the original code object of the forward method."""
|
||||
return self.__class__.forward.__code__
|
||||
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
|
||||
"""Hook to save the compiled bytecode for direct execution."""
|
||||
if old_code is not self.original_code_object():
|
||||
return
|
||||
@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def _dispatch_to_compiled_code(self):
|
||||
def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
|
||||
# noqa: E501
|
||||
"""
|
||||
Context manager to dispatch to internally compiled code for torch<2.8.
|
||||
|
||||
Reference in New Issue
Block a user