[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-06 17:37:51 -08:00
committed by GitHub
parent 6f351548b2
commit 873480d133
12 changed files with 103 additions and 85 deletions

View File

@@ -9,7 +9,7 @@ import operator
import os
import pprint
import time
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
@@ -90,7 +90,7 @@ class CompilerManager:
support int as key.
"""
def __init__(self, compilation_config: CompilationConfig):
def __init__(self, compilation_config: CompilationConfig) -> None:
self.cache: dict[tuple[Range, int, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
@@ -100,7 +100,7 @@ class CompilerManager:
return self.compiler.compute_hash(vllm_config)
@contextmanager
def compile_context(self, compile_range: Range):
def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
@@ -115,7 +115,7 @@ class CompilerManager:
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
"""
Initialize the cache directory for the compiler.
@@ -143,7 +143,7 @@ class CompilerManager:
# do not use eval(), it is unsafe.
cache = ast.literal_eval(f.read())
def check_type(value, ty):
def check_type(value: Any, ty: type) -> None:
if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
@@ -165,7 +165,7 @@ class CompilerManager:
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
)
def save_to_file(self):
def save_to_file(self) -> None:
if self.disable_cache or not self.is_cache_updated:
return
printer = pprint.PrettyPrinter(indent=4)
@@ -198,7 +198,7 @@ class CompilerManager:
def compile(
self,
graph: fx.GraphModule,
example_inputs,
example_inputs: list[Any],
additional_inductor_config,
compilation_config: CompilationConfig,
compile_range: Range,
@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
compile_submod_names: list[str],
vllm_config: VllmConfig,
vllm_backend: "VllmBackend",
):
) -> None:
super().__init__(module)
from torch._guards import detect_fake_mode
@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
def run(self, *args):
def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
@@ -467,7 +467,7 @@ model_is_encoder: bool = False
@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False):
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
"""Context manager to set the model tag."""
global model_tag
global model_is_encoder
@@ -521,7 +521,7 @@ class VllmBackend:
vllm_config: VllmConfig,
prefix: str = "",
is_encoder: bool = False,
):
) -> None:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
@@ -558,7 +558,7 @@ class VllmBackend:
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def configure_post_pass(self):
def configure_post_pass(self) -> None:
self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
@@ -580,7 +580,7 @@ class VllmBackend:
self.inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs
self, graph: fx.GraphModule, example_inputs: Sequence[Any]
) -> VllmSerializableFunction:
vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below.

View File

@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"):
class BasePattern:
def __init__(self, dtype: torch.dtype, device: str):
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)

View File

@@ -31,7 +31,7 @@ class CompilerInterface:
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
@@ -66,7 +66,7 @@ class CompilerInterface:
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs,
@@ -100,7 +100,7 @@ class CompilerInterface:
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
def __init__(self) -> None:
self.guards: list[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
return True
def get_pruned_guards(self, *args, **kwargs):
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
return []
def produce_guards_expression(self, *args, **kwargs):
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
return ""
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]):
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str:
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
self.cache_dir = cache_dir
def compile(
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
returns_tuple = graph_returns_tuple(graph)
def compiled_graph_wrapper(*args):
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
def hijack_load(*args, **kwargs):
def hijack_load(*args: Any, **kwargs: Any) -> Any:
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
# function renamed in 2.6
original_load_name = None
def hijacked_compile_fx_inner(*args, **kwargs):
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args, **kwargs):
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager:
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context()
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
else:
return contextlib.nullcontext()
def set_inductor_config(config, compile_range: Range):
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
)
def set_functorch_config():
def set_functorch_config() -> None:
torch._functorch.config.bundled_autograd_cache = False
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_eager_compiles += 1
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.

View File

@@ -3,7 +3,9 @@
import copy
import dataclasses
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
@dataclasses.dataclass
@@ -34,7 +36,7 @@ class CompilationCounter:
return copy.deepcopy(self)
@contextmanager
def expect(self, **kwargs):
def expect(self, **kwargs: Any) -> Generator[None, None, None]:
old = self.clone()
yield
for k, v in kwargs.items():

View File

@@ -219,6 +219,7 @@ class CUDAGraphWrapper:
# runtime modes.
return self.runnable(*args, **kwargs)
assert batch_descriptor is not None
if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(

View File

@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload, OpOverloadPacket
from torch.fx.node import Target
def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
def is_func(node: fx.Node, target: Target) -> bool:
return bool(node.op == "call_function" and node.target == target)
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:

View File

@@ -8,9 +8,9 @@ import hashlib
import inspect
import json
import types
from collections.abc import Callable
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
import torch
from torch import fx
@@ -30,6 +30,8 @@ else:
)
_pass_context = None
P = ParamSpec("P")
R = TypeVar("R")
class PassContext:
@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext:
@contextmanager
def pass_context(compile_range: Range):
def pass_context(compile_range: Range) -> Generator[None, None, None]:
"""A context manager that stores the current pass context,
usually it is a list of sizes to specialize.
"""
@@ -57,7 +59,7 @@ def pass_context(compile_range: Range):
_pass_context = prev_context
class InductorPass(CustomGraphPass):
class InductorPass(CustomGraphPass): # type: ignore[misc]
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass):
return InductorPass.hash_source(self)
@staticmethod
def hash_source(*srcs: str | Any):
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass):
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: dict[Any, Any]):
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_range(self, compile_range: Range):
def is_applicable_for_range(self, compile_range: Range) -> bool:
return True
@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass):
implementation of the UUID.
"""
def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None):
def __init__(
self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
) -> None:
self.callable = callable
self._uuid = self.hash_source(callable) if uuid is None else uuid
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
self.callable(graph)
def uuid(self) -> Any:
return self._uuid
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any:
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)

View File

@@ -12,7 +12,7 @@ context_manager = None
torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(vllm_config: VllmConfig):
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
global torch_compile_start_time
torch_compile_start_time = time.time()
@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
context_manager.__enter__()
def end_monitoring_torch_compile(vllm_config: VllmConfig):
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once(
@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled():
def validate_cudagraph_capturing_enabled() -> None:
# used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error.
@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled():
)
def set_cudagraph_capturing_enabled(enabled: bool):
def set_cudagraph_capturing_enabled(enabled: bool) -> None:
global cudagraph_capturing_enabled
cudagraph_capturing_enabled = enabled

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from collections.abc import Generator
import torch
@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
@contextlib.contextmanager
def inductor_partition_rule_context(splitting_ops: list[str]):
def inductor_partition_rule_context(
splitting_ops: list[str] | None,
) -> Generator[None, None, None]:
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the

View File

@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper:
self,
epsilon: float,
dtype: torch.dtype,
device: str,
):
device: str | None,
) -> None:
self.epsilon = epsilon
self.dtype = dtype
self.device = device
@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper:
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass):
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor,
arg3_1: torch.Tensor,
@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, NoReturn
import torch
@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
Return None to skip inductor code caching entirely.
"""
def __getstate__(self):
def __getstate__(self) -> Any | None:
"""
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
to enable subclasses to only have to implement uuid.
"""
return self.uuid()
def __setstate__(self, state):
def __setstate__(self, state: Any) -> NoReturn:
raise ValueError(
"Cannot unpickle CustomGraphPass because pickling"
" is used for cache key uuid. Use torch>=2.6 with"

View File

@@ -3,6 +3,7 @@
import functools
import operator
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar
@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass):
)
self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.device: str | None = (
config.device_config.device if config.device_config else None
)
self.pass_name = self.__class__.__name__
@staticmethod
def time_and_log(call_fn):
def time_and_log(
call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
@functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
self.begin()
self.dump_graph(graph, "before")
call_fn(self, graph)
@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass):
return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str):
def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}"
lazy_format_graph_code(
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
)
def begin(self):
def begin(self) -> None:
self._start_time = time.perf_counter_ns()
def end_and_log(self):
def end_and_log(self) -> None:
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass):
def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations"""
return self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
return str(
self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string,
)
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
"""
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs.
@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass):
class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig):
def __init__(self, name: str, config: VllmConfig) -> None:
super().__init__(config)
self.name = name
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: torch.fx.Graph) -> None:
self.dump_graph(graph, self.name)