[Misc][BE] Turn on strict type coverage for vllm/compilation (#31756)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-22 07:12:26 -08:00
committed by GitHub
parent d117a4d1a9
commit 15e302dfce
11 changed files with 121 additions and 68 deletions

View File

@@ -100,6 +100,13 @@ ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"
# Per-module override: stricter type checking for everything under
# vllm.compilation than the global [tool.mypy] settings require.
[[tool.mypy.overrides]]
module = "vllm.compilation.*"
# Every function definition must carry type annotations.
disallow_untyped_defs = true
# Partially annotated signatures are rejected as well.
disallow_incomplete_defs = true
# Warn when a function with a declared non-Any return type returns an Any value.
warn_return_any = true
# Imported modules are still analysed, but errors inside them are not reported.
follow_imports = "silent"
[tool.pytest.ini_options]
markers = [
"slow_test",

View File

@@ -28,7 +28,7 @@ def test_bad_callable():
pass_manager.configure(config)
with pytest.raises(AssertionError):
pass_manager.add(simple_callable)
pass_manager.add(simple_callable) # type: ignore[arg-type]
# Pass that inherits from InductorPass

View File

@@ -77,6 +77,11 @@ EXCLUDE = [
"vllm/v1/attention/ops",
]
# Directories that should be checked with --strict.
# Entries are path prefixes compared against changed file paths by
# is_strict_file(); files that match are run through mypy in a separate
# invocation with the --strict flag added.
STRICT_DIRS = [
"vllm/compilation",
]
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
"""
@@ -108,11 +113,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return file_groups
def is_strict_file(filepath: str) -> bool:
    """Check if a file should be checked with strict mode.

    A path is strict when it is exactly one of the entries in ``STRICT_DIRS``
    or lies underneath one of them. The match is anchored on a ``/`` boundary
    so that a sibling which merely shares the prefix (for example
    ``vllm/compilation_extra/foo.py`` or ``vllm/compilation.py`` against the
    entry ``vllm/compilation``) is not treated as strict.

    Args:
        filepath: Path of the file to classify, using ``/`` separators in the
            same form as the ``STRICT_DIRS`` entries.

    Returns:
        True if the file belongs to a strict directory, False otherwise.
    """
    return any(
        # rstrip("/") lets an entry be written with or without a trailing slash.
        filepath == strict_dir or filepath.startswith(strict_dir.rstrip("/") + "/")
        for strict_dir in STRICT_DIRS
    )
def mypy(
targets: list[str],
python_version: str | None,
follow_imports: str | None,
file_group: str,
strict: bool = False,
) -> int:
"""
Run mypy on the given targets.
@@ -124,6 +135,7 @@ def mypy(
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
file_group: The file group name for logging purposes.
strict: If True, run mypy with --strict flag.
Returns:
The return code from mypy.
@@ -133,6 +145,8 @@ def mypy(
args += ["--python-version", python_version]
if follow_imports is not None:
args += ["--follow-imports", follow_imports]
if strict:
args += ["--strict"]
print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode
@@ -149,8 +163,28 @@ def main():
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
# Separate files into strict and non-strict groups
strict_files = [f for f in changed_files if is_strict_file(f)]
non_strict_files = [f for f in changed_files if not is_strict_file(f)]
# Run mypy on non-strict files
if non_strict_files:
returncode |= mypy(
changed_files, python_version, follow_imports, file_group
non_strict_files,
python_version,
follow_imports,
file_group,
strict=False,
)
# Run mypy on strict files with --strict flag
if strict_files:
returncode |= mypy(
strict_files,
python_version,
follow_imports,
f"{file_group} (strict)",
strict=True,
)
return returncode

View File

@@ -68,7 +68,7 @@ def make_copy_and_call(
A wrapper function that copies inputs and calls the compiled function
"""
def copy_and_call(*args):
def copy_and_call(*args: Any) -> Any:
list_args = list(args)
for i, index in enumerate(sym_tensor_indices):
runtime_tensor = list_args[index]

View File

@@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts:
split on attn)
"""
def __init__(self):
def __init__(self) -> None:
# dict from submodule name to byte hash
self.submodule_bytes = {}
self.submodule_bytes: dict[str, str] = {}
# dict from byte hash to bytes
self.submodule_bytes_store = {}
self.submodule_bytes_store: dict[str, bytes] = {}
# dict from byte hash to loaded module
self.loaded_submodule_store = {}
self.loaded_submodule_store: dict[str, Any] = {}
def insert(self, submod_name: str, shape: str, entry: bytes):
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
hasher = hashlib.sha256()
hasher.update(entry)
hex_digest = hasher.hexdigest()
@@ -86,7 +86,7 @@ class StandaloneCompiledArtifacts:
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def get_loaded(self, submod_name: str, shape: str):
def get_loaded(self, submod_name: str, shape: str) -> Any:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
@@ -119,7 +119,7 @@ class StandaloneCompiledArtifacts:
from torch._inductor.standalone_compile import AOTCompiledArtifact
def _load_entry(entry_bytes) -> AOTCompiledArtifact:
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry)
@@ -132,13 +132,13 @@ class StandaloneCompiledArtifacts:
logger.debug("loaded all %s submodules", self.num_artifacts())
def __getstate__(self):
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
return {
"submodule_bytes": self.submodule_bytes,
"submodule_bytes_store": self.submodule_bytes_store,
}
def __setstate__(self, state):
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {}
@@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
standalone_compile_artifacts.load_all()
submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable]] = {}
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
for cache_key in standalone_compile_artifacts.submodule_bytes:
submod_name, shape_str = cache_key.rsplit("_", 1)
@@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these.
continue
hash_content.append(content)
return safe_hash(
result: str = safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()
return result
def _compute_code_hash(files: set[str]) -> str:

View File

@@ -30,19 +30,15 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
try:
import flashinfer.comm as flashinfer_comm
import flashinfer.comm as _flashinfer_comm
flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef]
flashinfer_comm
if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
else None
)
if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"):
flashinfer_comm = _flashinfer_comm
except ImportError:
flashinfer_comm = None # type: ignore[assignment]
else:
flashinfer_comm = None # type: ignore[assignment]
pass
logger = init_logger(__name__)
@@ -441,7 +437,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
):
return True
tp_size = get_tensor_model_parallel_world_size()
return compile_range.is_single_size() and compile_range.end % tp_size == 0
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
@@ -516,7 +512,7 @@ if flashinfer_comm is not None:
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, # type: ignore[arg-type]
device_capability, # type: ignore[arg-type, unused-ignore]
{},
).get(world_size, None)
# Use one shot if no max size is specified
@@ -666,6 +662,7 @@ class AllReduceRMSNormPattern(BasePattern):
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -722,6 +719,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -800,6 +798,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -875,6 +874,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -960,6 +960,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -1055,6 +1056,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight: torch.Tensor,
input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@@ -1131,7 +1133,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
)
self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc]
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=self.max_token_num,
@@ -1204,7 +1206,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
if self.disabled:
logger.warning_once("AllReduce fusion pass is disabled.")
return False
return compile_range.end <= self.max_token_num
return bool(compile_range.end <= self.max_token_num)
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:

View File

@@ -201,9 +201,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
hash_str: str = safe_hash(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
return hash_str
def initialize_cache(
@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
hash_str: str = safe_hash(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
return hash_str
def initialize_cache(

View File

@@ -45,10 +45,10 @@ logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=type[nn.Module])
_T = TypeVar("_T", bound=nn.Module)
def ignore_torch_compile(cls: _T) -> _T:
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
"""
A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has
@@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return cls
def _should_ignore_torch_compile(cls: _T) -> bool:
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
@@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool:
def support_torch_compile(
*,
enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T]: ...
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(
*,
dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(
*,
mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...
) -> Callable[[type[_T]], type[_T]]: ...
@overload
@@ -101,21 +101,21 @@ def support_torch_compile(
*,
dynamic_arg_dims: dict[str, int | list[int]] | None,
mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...
) -> Callable[[type[_T]], type[_T]]: ...
@overload
def support_torch_compile(cls: _T) -> _T: ...
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
def support_torch_compile(
cls: _T | None = None,
cls: type[_T] | None = None,
*,
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
) -> Callable[[type[_T]], type[_T]] | type[_T]:
"""
A decorator to add support for compiling the forward method of a class.
@@ -182,7 +182,7 @@ def support_torch_compile(
errors.
"""
def cls_decorator_helper(cls: _T) -> _T:
def cls_decorator_helper(cls: type[_T]) -> type[_T]:
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile`
if not hasattr(cls, "forward"):
@@ -263,12 +263,12 @@ def _verify_source_unchanged(
def _support_torch_compile(
cls: _T,
cls: type[_T],
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
) -> type[_T]:
"""
A decorator to add support for compiling the forward method of a class.
"""
@@ -325,12 +325,12 @@ def _support_torch_compile(
self.compiled = False
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type]
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
def _mark_dynamic_inputs(
mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED:
@@ -382,7 +382,7 @@ def _support_torch_compile(
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
@@ -564,7 +564,7 @@ def _support_torch_compile(
return output
# triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self):
def save_aot_compiled_function(self: type[_T]) -> None:
if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save")
return

View File

@@ -141,7 +141,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
key: torch.Tensor | None,
cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return RotaryEmbedding.forward_static(
result: tuple[torch.Tensor, torch.Tensor | None] = (
RotaryEmbedding.forward_static(
positions,
query,
key,
@@ -150,6 +151,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
cos_sin_cache,
self.is_neox,
)
)
return result
class MatcherRMSNorm(MatcherCustomOp):
@@ -275,9 +278,10 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return RMSNorm.forward_static(
result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
return result
class MatcherQuantFP8(MatcherCustomOp):

View File

@@ -25,7 +25,7 @@ logger = init_logger(__name__)
class RangeEntry:
compile_range: Range
compiled: bool = False
runnable: Callable = None # type: ignore
runnable: Callable[..., Any] = None # type: ignore
class PiecewiseBackend:
@@ -38,7 +38,7 @@ class PiecewiseBackend:
sym_shape_indices: list[int],
vllm_backend: VllmBackend,
returns_tuple: bool,
compiled_runnables: dict[str, Callable] | None = None,
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
):
"""
The backend for piecewise compilation.
@@ -138,8 +138,10 @@ class PiecewiseBackend:
self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper(self, compiled_graph):
def compiled_graph_wrapper(*args):
def get_compiled_graph_wrapper(
self, compiled_graph: Callable[..., Any]
) -> Callable[..., Any]:
def compiled_graph_wrapper(*args: Any) -> Any:
graph_output = compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
@@ -163,7 +165,7 @@ class PiecewiseBackend:
def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj):
def reducer_override(self, obj: object) -> Any:
if isinstance(obj, CachingAutotuner):
obj.prepare_for_pickle()
return pickle.loads, (
@@ -173,7 +175,7 @@ class PiecewiseBackend:
)
return NotImplemented
def serialize(fn) -> bytes:
def serialize(fn: Callable[..., Any]) -> bytes:
assert hasattr(fn, "serialize"), "fn must have serialize method"
with torch._functorch.config.patch("bundled_autograd_cache", True):
entry = fn.serialize()

View File

@@ -358,7 +358,10 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
):
return True
tp_size = get_tensor_model_parallel_world_size()
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
result: bool = (compile_range.is_single_size()) and (
compile_range.end % tp_size == 0
)
return result
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None: