[Misc][BE] Type coverage for vllm/compilation [2/3] (#31744)

This commit is contained in:
Lucas Kabela
2026-01-09 15:30:38 -08:00
committed by GitHub
parent 3adffd5b90
commit aaf4b70aae
12 changed files with 161 additions and 91 deletions

View File

@@ -179,7 +179,7 @@ class CompilerManager:
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable | None:
) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
@@ -199,7 +199,7 @@ class CompilerManager:
self,
graph: fx.GraphModule,
example_inputs: list[Any],
additional_inductor_config,
additional_inductor_config: dict[str, Any],
compilation_config: CompilationConfig,
compile_range: Range,
graph_index: int = 0,
@@ -355,7 +355,7 @@ def split_graph(
compilation_start_time = 0.0
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compiles some
submodules specified by `compile_submod_names` with the given
@@ -506,9 +506,9 @@ class VllmBackend:
# the stitching graph module for all the piecewise graphs
split_gm: fx.GraphModule
piecewise_graphs: list[SplitItem]
returned_callable: Callable
returned_callable: Callable[..., Any]
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]
post_grad_passes: Sequence[Callable[..., Any]]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager
@@ -821,7 +821,7 @@ class VllmBackend:
]
# this is the callable we return to Dynamo to run
def copy_and_call(*args):
def copy_and_call(*args: Any) -> Any:
list_args = list(args)
for i, index in enumerate(self.sym_tensor_indices):
runtime_tensor = list_args[index]

View File

@@ -4,6 +4,8 @@
import inspect
import os
import pickle
from collections.abc import Callable, Sequence
from typing import Any, Literal
from unittest.mock import patch
import torch
@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type)
logger = init_logger(__name__)
class VllmSerializableFunction(SerializableCallable):
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
inputs to the compiled function and return the result.
@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
"""
def __init__(
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
):
self,
graph_module: torch.fx.GraphModule,
example_inputs: Sequence[Any],
prefix: str,
optimized_call: Callable[..., Any],
is_encoder: bool = False,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
self.example_inputs = example_inputs
@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable):
if sym_input is not None:
self.shape_env = sym_input.node.shape_env
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.optimized_call(*args, **kwargs)
@classmethod
@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable):
graph_reducer_override = GraphPickler.reducer_override
def _graph_reducer_override(self, obj):
def _graph_reducer_override(
self: GraphPickler, obj: Any
) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
if (
inspect.isclass(obj)
and issubclass(obj, sympy.Function)
@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable):
get_current_vllm_config(), state["prefix"], is_encoder
)
def optimized_call(*example_inputs):
def optimized_call(*example_inputs: Any) -> Any:
"""
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable):
return fn
@property
def co_name(self):
def co_name(self) -> Literal["VllmSerializableFunction"]:
"""
Used for depyf debugging.
"""

View File

@@ -42,7 +42,9 @@ class CUDAGraphLogging:
"Count",
]
def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
def __init__(
self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
) -> None:
self.reset()
self.cg_mode = str(cg_mode)
self.cg_capture_sizes = str(cg_capture_sizes or [])
@@ -54,10 +56,10 @@ class CUDAGraphLogging:
"**CUDAGraph Stats:**\n\n"
)
def reset(self):
self.stats = []
def reset(self) -> None:
self.stats: list[CUDAGraphStat] = []
def observe(self, cudagraph_stat: CUDAGraphStat):
def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
self.stats.append(cudagraph_stat)
def generate_metric_table(self) -> str:
@@ -109,7 +111,7 @@ class CUDAGraphLogging:
+ "\n"
)
def log(self, log_fn=logger.info):
def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
if not self.stats:
return
log_fn(self.generate_metric_table())
@@ -161,11 +163,11 @@ class CUDAGraphWrapper:
def __init__(
self,
runnable: Callable,
runnable: Callable[..., Any],
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
):
) -> None:
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
@@ -189,7 +191,7 @@ class CUDAGraphWrapper:
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
def __getattr__(self, key: str):
def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
@@ -198,11 +200,11 @@ class CUDAGraphWrapper:
f"cudagraph wrapper: {self.runnable}"
)
def unwrap(self) -> Callable:
def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable.
return self.runnable
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

View File

@@ -6,8 +6,8 @@ import hashlib
import inspect
import os
import sys
from collections.abc import Callable
from typing import TypeVar, overload
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
from unittest.mock import patch
import torch
@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from .monitor import start_monitoring_torch_compile
if TYPE_CHECKING:
# Only added on nightly/2.10 so wrap
try:
from torch._dynamo.package import SourceInfo
except ImportError:
# Fallback for old versions not supporting SourceInfo
SourceInfo = Any
logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return cls
def _should_ignore_torch_compile(cls) -> bool:
def _should_ignore_torch_compile(cls: _T) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
@@ -224,7 +232,7 @@ def support_torch_compile(
return cls_decorator_helper
def _model_hash_key(fn) -> str:
def _model_hash_key(fn: Callable[..., Any]) -> str:
import vllm
sha256_hash = hashlib.sha256()
@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
return sha256_hash.hexdigest()
def _verify_source_unchanged(source_info, vllm_config) -> None:
def _verify_source_unchanged(
source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
from .caching import _compute_code_hash, _compute_code_hash_with_content
file_contents = {}
@@ -275,8 +285,12 @@ def _support_torch_compile(
setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__(
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
):
self: _T,
*,
vllm_config: VllmConfig | None = None,
prefix: str = "",
**kwargs: Any,
) -> None:
if vllm_config is None:
vllm_config = get_current_vllm_config()
@@ -309,13 +323,17 @@ def _support_torch_compile(
compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type]
cls.__init__ = __init__
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED:
def _mark_dynamic_inputs(
mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED:
if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
@@ -326,7 +344,7 @@ def _support_torch_compile(
else:
torch._dynamo.mark_dynamic(arg, dims)
sig = inspect.signature(mod.__class__.forward)
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
@@ -364,7 +382,7 @@ def _support_torch_compile(
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self, *args, **kwargs):
def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
@@ -444,7 +462,7 @@ def _support_torch_compile(
not envs.VLLM_USE_AOT_COMPILE
or self.vllm_config.compilation_config.backend == "eager"
)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
@@ -477,7 +495,7 @@ def _support_torch_compile(
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call_
def patched_inline_call(self_):
def patched_inline_call(self_: Any) -> Any:
code = self_.f_code
self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
@@ -535,7 +553,7 @@ def _support_torch_compile(
str(e),
)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
self.compiled = True
return output
@@ -545,7 +563,9 @@ def _support_torch_compile(
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
def maybe_use_cudagraph_partition_wrapper(
vllm_config: VllmConfig,
) -> Generator[None, None, None]:
"""
Context manager to set/unset customized cudagraph partition wrappers.
@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
current_platform.get_static_graph_wrapper_cls()
)
def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
def customized_cudagraph_wrapper(
f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
) -> Any:
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions
return static_graph_wrapper_class(
@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
"""
Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
_ColumnvLLMParameter,
)
def return_false(*args, **kwargs):
def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
return False
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):

View File

@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
# XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu():
@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass):
)
self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]):
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
"""
Stage a node (or nodes) for removal at the end of the pass.
"""
@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass):
node: torch.fx.Node,
mutated_args: dict[int, torch.fx.Node | str],
args: tuple[torch.fx.Node | str, ...] | None = None,
):
) -> None:
"""
De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments.
@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass):
def replace_users_with_mutated_args(
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
):
) -> None:
"""
Replace all getitem users of the auto-functionalized node with the
mutated arguments.
@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass):
graph: torch.fx.Graph,
node: torch.fx.Node,
args: tuple[torch.fx.Node | str, ...] | None = None,
):
) -> None:
"""
Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly,

View File

@@ -29,6 +29,9 @@ else:
Torch25CustomGraphPass as CustomGraphPass,
)
# Re-export CustomGraphPass for external usage
__all__ = ["CustomGraphPass"]
_pass_context = None
P = ParamSpec("P")
R = TypeVar("R")

View File

@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass):
"""
@VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
2. The dimensions both correspond to the same SymInt
"""
# Case 1
return statically_known_true(dim == i_dim)
return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
from torch import fx as fx
@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def with_pattern_match_debug(fn):
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
"""
Function decorator that turns on inductor pattern match debug
for the duration of the call.
@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn):
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn):
return wrapper
class PostGradPassManager(CustomGraphPass):
class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass):
This way, all passes operate on a functionalized graph.
"""
def __init__(self):
def __init__(self) -> None:
self.passes: list[InductorPass] = []
@with_pattern_match_debug
def __call__(self, graph: fx.Graph):
def __call__(self, graph: fx.Graph) -> None:
VllmInductorPass.dump_prefix = 0 # reset dump index
compile_range = get_pass_context().compile_range
@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass):
self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig):
def configure(self, config: VllmConfig) -> None:
self.pass_config = config.compilation_config.pass_config
# Set the current vllm config to allow tracing CustomOp instances
@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass):
def add(self, pass_: InductorPass) -> None:
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def uuid(self):
def uuid(self) -> str:
"""
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state = {"pass_config": self.pass_config.compute_hash(), "passes": []}
passes = []
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
passes.append(pass_.uuid())
passes.append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range)
state["passes"] = passes
return InductorPass.hash_dict(state)

View File

@@ -86,27 +86,36 @@ class PiecewiseBackend:
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly.
for size in self.compile_sizes:
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
if self.compile_sizes is not None:
for size in self.compile_sizes:
if isinstance(size, str):
assert size == "cudagraph_capture_sizes"
raise NotImplementedError(
"cudagraph_capture_sizes not supported in compile_sizes."
"This should be handled in `post_init_cudagraph_sizes`."
)
else:
assert isinstance(size, int)
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
def check_for_ending_compilation(self):
def check_for_ending_compilation(self) -> None:
if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
def _fakify_args(self, args: list[Any]) -> list[Any]:
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs potentially causing some non dynamic
dimension to be duck shaped to other existing shapes that have hints
@@ -127,7 +136,9 @@ class PiecewiseBackend:
assert len(fake_example_inputs) == len(args)
return fake_example_inputs
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any:
def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
if not range_entry.compiled:
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
@@ -136,14 +147,14 @@ class PiecewiseBackend:
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args = (
args_list = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else args
else list(args)
)
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
@@ -153,10 +164,13 @@ class PiecewiseBackend:
self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if self.compile_sizes is None:
return None
if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
else:
@@ -165,7 +179,7 @@ class PiecewiseBackend:
return self.range_entries[range]
return None
def __call__(self, *args) -> Any:
def __call__(self, *args: Any) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)

View File

@@ -4,9 +4,10 @@
import os
import sys
from abc import abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any
from typing import Any, ParamSpec, TypeVar
import torch
import torch._C._dynamo.guards
@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__)
R = TypeVar("R")
P = ParamSpec("P")
def _noop_add_global_state_guard(self, *args, **kwargs):
def _noop_add_global_state_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the GLOBAL_STATE guard entirely"""
pass
def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
def _noop_add_torch_function_mode_stack_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass
@contextmanager
def _compilation_context():
def _compilation_context() -> Generator[None, None, None]:
"""Context manager for compilation settings and patches.
This manager:
@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
def check_invariants_and_forward(self, *args, **kwargs):
def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
assert hasattr(self, "_check_shape_invariants")
self._check_shape_invariants(*args, **kwargs)
return self.forward(*args, **kwargs)
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
def _call_with_optional_nvtx_range(
self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Any:
if self.layerwise_nvtx_tracing_enabled:
args_list = list(args)
kwargs_dict = dict(kwargs)
@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper:
return ctx.result
return callable_fn(*args, **kwargs)
def __init__(self):
def __init__(self) -> None:
self.compiled = False
vllm_config = get_current_vllm_config()
@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper:
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
self._compiled_bytecode = None
self._compiled_bytecode: CodeType | None = None
def aot_compile(self, *args, **kwargs):
def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
if not hasattr(self._compiled_callable, "aot_compile"):
raise RuntimeError(
"aot_compile is not supported by the current configuration. "
@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper:
)
return self._compiled_callable.aot_compile((args, kwargs))
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if envs.VLLM_USE_BYTECODE_HOOK:
if (
self.vllm_config.compilation_config.mode
@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper:
)
@abstractmethod
def forward(self, *args, **kwargs): ...
def forward(self, *args: Any, **kwargs: Any) -> Any: ...
def original_code_object(self) -> CodeType:
"""Return the original code object of the forward method."""
return self.__class__.forward.__code__
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
"""Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object():
return
@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper:
raise RuntimeError(msg)
@contextmanager
def _dispatch_to_compiled_code(self):
def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
# noqa: E501
"""
Context manager to dispatch to internally compiled code for torch<2.8.