[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -31,7 +31,7 @@ class CompilerInterface:
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
@@ -66,7 +66,7 @@ class CompilerInterface:
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a range. The `compile_range` specifies the range of the inputs,
|
||||
@@ -100,7 +100,7 @@ class CompilerInterface:
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
) -> Callable[..., Any]:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
|
||||
def __init__(self) -> None:
|
||||
self.guards: list[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args, **kwargs):
|
||||
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args, **kwargs):
|
||||
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args, **kwargs):
|
||||
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]):
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
|
||||
self.save_format = save_format
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
def compiled_graph_wrapper(*args):
|
||||
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
|
||||
graph_output = inductor_compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
):
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
||||
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
original_load = FxGraphCache.load
|
||||
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
|
||||
|
||||
def hijack_load(*args, **kwargs):
|
||||
def hijack_load(*args: Any, **kwargs: Any) -> Any:
|
||||
inductor_compiled_graph = original_load(*args, **kwargs)
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
# function renamed in 2.6
|
||||
original_load_name = None
|
||||
|
||||
def hijacked_compile_fx_inner(*args, **kwargs):
|
||||
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args, **kwargs):
|
||||
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args, **kwargs):
|
||||
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable:
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args):
|
||||
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
return compiled_graph
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager:
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config, compile_range: Range):
|
||||
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
|
||||
if compile_range.is_single_size():
|
||||
# for a specific batch size, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
|
||||
)
|
||||
|
||||
|
||||
def set_functorch_config():
|
||||
def set_functorch_config() -> None:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
|
||||
Reference in New Issue
Block a user