[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -9,7 +9,7 @@ import operator
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
@@ -90,7 +90,7 @@ class CompilerManager:
|
||||
support int as key.
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_config: CompilationConfig):
|
||||
def __init__(self, compilation_config: CompilationConfig) -> None:
|
||||
self.cache: dict[tuple[Range, int, str], Any] = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
@@ -100,7 +100,7 @@ class CompilerManager:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
|
||||
@contextmanager
|
||||
def compile_context(self, compile_range: Range):
|
||||
def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
|
||||
"""Provide compilation context for the duration of compilation to set
|
||||
any torch global properties we want to scope to a single Inductor
|
||||
compilation (e.g. partition rules, pass context)."""
|
||||
@@ -115,7 +115,7 @@ class CompilerManager:
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache directory for the compiler.
|
||||
|
||||
@@ -143,7 +143,7 @@ class CompilerManager:
|
||||
# do not use eval(), it is unsafe.
|
||||
cache = ast.literal_eval(f.read())
|
||||
|
||||
def check_type(value, ty):
|
||||
def check_type(value: Any, ty: type) -> None:
|
||||
if not isinstance(value, ty):
|
||||
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
|
||||
|
||||
@@ -165,7 +165,7 @@ class CompilerManager:
|
||||
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
|
||||
)
|
||||
|
||||
def save_to_file(self):
|
||||
def save_to_file(self) -> None:
|
||||
if self.disable_cache or not self.is_cache_updated:
|
||||
return
|
||||
printer = pprint.PrettyPrinter(indent=4)
|
||||
@@ -198,7 +198,7 @@ class CompilerManager:
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs,
|
||||
example_inputs: list[Any],
|
||||
additional_inductor_config,
|
||||
compilation_config: CompilationConfig,
|
||||
compile_range: Range,
|
||||
@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
compile_submod_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
vllm_backend: "VllmBackend",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
||||
self.extra_traceback = False
|
||||
|
||||
def run(self, *args):
|
||||
def run(self, *args: Any) -> Any:
|
||||
# maybe instead just assert inputs are fake?
|
||||
fake_args = [
|
||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
@@ -467,7 +467,7 @@ model_is_encoder: bool = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_model_tag(tag: str, is_encoder: bool = False):
|
||||
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
|
||||
"""Context manager to set the model tag."""
|
||||
global model_tag
|
||||
global model_is_encoder
|
||||
@@ -521,7 +521,7 @@ class VllmBackend:
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
is_encoder: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
# if the model is initialized with a non-empty prefix,
|
||||
# then usually it's enough to use that prefix,
|
||||
# e.g. language_model, vision_model, etc.
|
||||
@@ -558,7 +558,7 @@ class VllmBackend:
|
||||
# `torch.compile` is JIT compiled, so we don't need to
|
||||
# do anything here
|
||||
|
||||
def configure_post_pass(self):
|
||||
def configure_post_pass(self) -> None:
|
||||
self.pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
@@ -580,7 +580,7 @@ class VllmBackend:
|
||||
self.inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def __call__(
|
||||
self, graph: fx.GraphModule, example_inputs
|
||||
self, graph: fx.GraphModule, example_inputs: Sequence[Any]
|
||||
) -> VllmSerializableFunction:
|
||||
vllm_config = self.vllm_config
|
||||
# Minimal hashing here with existing utilities, reused below.
|
||||
|
||||
@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
|
||||
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str):
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
super().__init__(dtype, device)
|
||||
@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
super().__init__(dtype, device)
|
||||
@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
super().__init__(dtype, device)
|
||||
@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
super().__init__(dtype, device)
|
||||
@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
super().__init__(dtype, device)
|
||||
@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
):
|
||||
super().__init__(dtype, device)
|
||||
|
||||
@@ -31,7 +31,7 @@ class CompilerInterface:
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
@@ -66,7 +66,7 @@ class CompilerInterface:
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a range. The `compile_range` specifies the range of the inputs,
|
||||
@@ -100,7 +100,7 @@ class CompilerInterface:
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
) -> Callable[..., Any]:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
|
||||
def __init__(self) -> None:
|
||||
self.guards: list[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]):
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
|
||||
self.save_format = save_format
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
def compiled_graph_wrapper(*args):
|
||||
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
|
||||
graph_output = inductor_compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
||||
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
def hijack_load(*args: Any, **kwargs: Any) -> Any:
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
return compiled_graph
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager:
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, compile_range: Range):
|
||||
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
|
||||
if compile_range.is_single_size():
|
||||
# for a specific batch size, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
|
||||
)
|
||||
|
||||
|
||||
def set_functorch_config():
|
||||
def set_functorch_config() -> None:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -34,7 +36,7 @@ class CompilationCounter:
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@contextmanager
|
||||
def expect(self, **kwargs):
|
||||
def expect(self, **kwargs: Any) -> Generator[None, None, None]:
|
||||
old = self.clone()
|
||||
yield
|
||||
for k, v in kwargs.items():
|
||||
|
||||
@@ -219,6 +219,7 @@ class CUDAGraphWrapper:
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
|
||||
|
||||
@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
from torch.fx.node import Target
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target) -> bool:
|
||||
return node.op == "call_function" and node.target == target
|
||||
def is_func(node: fx.Node, target: Target) -> bool:
|
||||
return bool(node.op == "call_function" and node.target == target)
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
|
||||
@@ -8,9 +8,9 @@ import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
@@ -30,6 +30,8 @@ else:
|
||||
)
|
||||
|
||||
_pass_context = None
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class PassContext:
|
||||
@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(compile_range: Range):
|
||||
def pass_context(compile_range: Range) -> Generator[None, None, None]:
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
@@ -57,7 +59,7 @@ def pass_context(compile_range: Range):
|
||||
_pass_context = prev_context
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass):
|
||||
class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass):
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: str | Any):
|
||||
def hash_source(*srcs: str | Any) -> str:
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass):
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]):
|
||||
def hash_dict(dict_: dict[Any, Any]) -> str:
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass):
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range):
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass):
|
||||
implementation of the UUID.
|
||||
"""
|
||||
|
||||
def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None):
|
||||
def __init__(
|
||||
self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
|
||||
) -> None:
|
||||
self.callable = callable
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.callable(graph)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args, **kwargs) -> Any:
|
||||
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||
global torch_compile_start_time
|
||||
torch_compile_start_time = time.time()
|
||||
|
||||
@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
logger.info_once(
|
||||
@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
cudagraph_capturing_enabled: bool = True
|
||||
|
||||
|
||||
def validate_cudagraph_capturing_enabled():
|
||||
def validate_cudagraph_capturing_enabled() -> None:
|
||||
# used to monitor whether a cudagraph capturing is legal at runtime.
|
||||
# should be called before any cudagraph capturing.
|
||||
# if an illegal cudagraph capturing happens, raise an error.
|
||||
@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled():
|
||||
)
|
||||
|
||||
|
||||
def set_cudagraph_capturing_enabled(enabled: bool):
|
||||
def set_cudagraph_capturing_enabled(enabled: bool) -> None:
|
||||
global cudagraph_capturing_enabled
|
||||
cudagraph_capturing_enabled = enabled
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
|
||||
@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def inductor_partition_rule_context(splitting_ops: list[str]):
|
||||
def inductor_partition_rule_context(
|
||||
splitting_ops: list[str] | None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager to temporarily register Inductor partition rules.
|
||||
|
||||
Registers custom partition rules for specified operators, forcing the
|
||||
|
||||
@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper:
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
device: str | None,
|
||||
) -> None:
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper:
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
return [input, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
device: str | None,
|
||||
):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, NoReturn
|
||||
|
||||
import torch
|
||||
|
||||
@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
|
||||
Return None to skip inductor code caching entirely.
|
||||
"""
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> Any | None:
|
||||
"""
|
||||
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
|
||||
to enable subclasses to only have to implement uuid.
|
||||
"""
|
||||
return self.uuid()
|
||||
|
||||
def __setstate__(self, state):
|
||||
def __setstate__(self, state: Any) -> NoReturn:
|
||||
raise ValueError(
|
||||
"Cannot unpickle CustomGraphPass because pickling"
|
||||
" is used for cache key uuid. Use torch>=2.6 with"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass):
|
||||
)
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
self.device: str | None = (
|
||||
config.device_config.device if config.device_config else None
|
||||
)
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def time_and_log(call_fn):
|
||||
def time_and_log(
|
||||
call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
|
||||
) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
|
||||
@functools.wraps(call_fn)
|
||||
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph):
|
||||
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before")
|
||||
call_fn(self, graph)
|
||||
@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass):
|
||||
|
||||
return wrapped
|
||||
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str):
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
|
||||
i = VllmInductorPass.dump_prefix
|
||||
i_str = "" if i is None else f".{i}"
|
||||
lazy_format_graph_code(
|
||||
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
|
||||
)
|
||||
|
||||
def begin(self):
|
||||
def begin(self) -> None:
|
||||
self._start_time = time.perf_counter_ns()
|
||||
|
||||
def end_and_log(self):
|
||||
def end_and_log(self) -> None:
|
||||
self._end_time = time.perf_counter_ns()
|
||||
duration_ms = float(self._end_time - self._start_time) / 1.0e6
|
||||
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
|
||||
def _replace_op_overloads(self, string: str) -> str:
|
||||
"""Replace <OpOverload(..., ...)> with nicer formulations"""
|
||||
return self._OP_OVERLOAD_PATTERN.sub(
|
||||
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
|
||||
string,
|
||||
return str(
|
||||
self._OP_OVERLOAD_PATTERN.sub(
|
||||
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
|
||||
string,
|
||||
)
|
||||
)
|
||||
|
||||
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass):
|
||||
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
|
||||
"""
|
||||
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
|
||||
into the debug_dump_path folder next to the dumped fx graphs.
|
||||
@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
|
||||
def __init__(self, name: str, config: VllmConfig):
|
||||
def __init__(self, name: str, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.name = name
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.dump_graph(graph, self.name)
|
||||
|
||||
Reference in New Issue
Block a user