[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-06 17:37:51 -08:00
committed by GitHub
parent 6f351548b2
commit 873480d133
12 changed files with 103 additions and 85 deletions

View File

@@ -31,7 +31,7 @@ class CompilerInterface:
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
@@ -66,7 +66,7 @@ class CompilerInterface:
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs,
@@ -100,7 +100,7 @@ class CompilerInterface:
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
def __init__(self) -> None:
self.guards: list[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
return True
def get_pruned_guards(self, *args, **kwargs):
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
return []
def produce_guards_expression(self, *args, **kwargs):
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
return ""
@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]):
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str:
@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
self.cache_dir = cache_dir
def compile(
@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
returns_tuple = graph_returns_tuple(graph)
def compiled_graph_wrapper(*args):
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
def hijack_load(*args, **kwargs):
def hijack_load(*args: Any, **kwargs: Any) -> Any:
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
# function renamed in 2.6
original_load_name = None
def hijacked_compile_fx_inner(*args, **kwargs):
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output
def hijack_compiled_fx_graph_hash(*args, **kwargs):
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out
def _check_can_cache(*args, **kwargs):
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run
def compiled_graph(*args):
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager:
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context()
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
else:
return contextlib.nullcontext()
def set_inductor_config(config, compile_range: Range):
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
)
def set_functorch_config():
def set_functorch_config() -> None:
torch._functorch.config.bundled_autograd_cache = False
@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_eager_compiles += 1
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.