Make mypy opt-out instead of opt-in (#33205)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-01-29 09:12:26 +00:00
committed by GitHub
parent a650ad1588
commit fb946a7f89
11 changed files with 35 additions and 57 deletions

View File

@@ -23,39 +23,8 @@ import sys
import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/entrypoints",
"vllm/executor",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/renderers",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
"vllm/utils",
"vllm/worker",
"vllm/v1/attention",
"vllm/v1/core",
"vllm/v1/engine",
"vllm/v1/executor",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/structured_output",
"vllm/v1/worker",
]
# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
# from "skip" to "silent" for a directory, remove it from SEPARATE_GROUPS.
SEPARATE_GROUPS = [
"tests",
# v0 related
@@ -74,6 +43,16 @@ EXCLUDE = [
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/v1/attention/ops",
# TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks",
"vllm/config",
"vllm/device_allocator",
"vllm/profiler",
"vllm/reasoning",
"vllm/tool_parser",
"vllm/v1/cudagraph_dispatcher.py",
"vllm/outputs.py",
"vllm/logger.py",
]
@@ -88,7 +67,6 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
A dictionary mapping file group names to lists of changed files.
"""
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
file_groups = {"": []}
file_groups.update({k: [] for k in SEPARATE_GROUPS})
for changed_file in changed_files:
@@ -96,14 +74,13 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
if exclude_pattern.match(changed_file):
continue
# Group files by mypy call
if files_pattern.match(changed_file):
file_groups[""].append(changed_file)
continue
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
else:
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
if changed_file.startswith("vllm/"):
file_groups[""].append(changed_file)
return file_groups

View File

@@ -349,7 +349,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None:
from aiter.mla import mla_decode_fwd
kwargs = {
kwargs: dict[str, float | torch.Tensor | None] = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}

View File

@@ -570,7 +570,7 @@ class CompilationConfig:
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
max_cudagraph_capture_size: int | None = field(default=None)
max_cudagraph_capture_size: int = field(default=None)
"""The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest
@@ -743,6 +743,7 @@ class CompilationConfig:
"level",
"mode",
"cudagraph_mode",
"max_cudagraph_capture_size",
"use_inductor_graph_partition",
mode="wrap",
)

View File

@@ -9,7 +9,7 @@ import inspect
import json
import pathlib
import textwrap
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
@@ -75,7 +75,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter(
object: object,
names: Iterable[str],
names: Sequence[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,

View File

@@ -382,7 +382,7 @@ def _patch_get_raw_stream_if_needed():
if hasattr(torch._C, "_cuda_getCurrentRawStream"):
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
builtins.get_raw_stream = _get_raw_stream
builtins.get_raw_stream = _get_raw_stream # type: ignore[attr-defined]
_patch_get_raw_stream_if_needed()

View File

@@ -1680,6 +1680,7 @@ def disable_envs_cache() -> None:
global __getattr__
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
if _is_envs_cache_enabled():
assert hasattr(__getattr__, "__wrapped__")
__getattr__ = __getattr__.__wrapped__

View File

@@ -270,7 +270,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):

View File

@@ -157,7 +157,7 @@ _METHODS_TO_PATCH = {
def _configure_vllm_root_logger() -> None:
logging_config = dict[str, Any]()
logging_config = dict[str, dict[str, Any] | Any]()
if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError(

View File

@@ -28,7 +28,7 @@ LogprobsOnePosition = dict[int, Logprob]
@dataclass
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
class FlatLogprobs(MutableSequence[LogprobsOnePosition | None]):
"""
Flat logprobs of a request into multiple primitive type lists.
@@ -140,7 +140,7 @@ class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
def __delitem__(self, item) -> None:
raise TypeError("Cannot delete logprobs from FlatLogprobs")
def insert(self, item) -> None:
def insert(self, index: int, value: dict[int, Logprob] | None) -> None:
raise TypeError("Cannot insert logprobs to FlatLogprobs")
def __iter__(self) -> Iterator[LogprobsOnePosition]:
@@ -161,7 +161,7 @@ SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request"""
logprobs = FlatLogprobs() if flat_logprobs else []
logprobs: PromptLogprobs = FlatLogprobs() if flat_logprobs else []
# NOTE: logprob of first prompt token is None.
logprobs.append(None)
return logprobs

View File

@@ -4,6 +4,7 @@
import dataclasses
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata

# `_typeshed` is a stub-only package that type checkers provide; it does not
# exist at runtime, so an unconditional import raises ModuleNotFoundError.
# Import it for type checking only and fall back to `Any` at runtime so that
# annotations referencing `DataclassInstance` still evaluate.
if TYPE_CHECKING:
    from _typeshed import DataclassInstance
else:
    DataclassInstance = Any
#
@@ -11,7 +12,7 @@ from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
#
def trim_string_front(string, width):
def trim_string_front(string: str, width: int) -> str:
if len(string) > width:
offset = len(string) - width + 3
string = string[offset:]
@@ -20,7 +21,7 @@ def trim_string_front(string, width):
return string
def trim_string_back(string, width):
def trim_string_back(string: str, width: int) -> str:
if len(string) > width:
offset = len(string) - width + 3
string = string[:-offset]
@@ -30,15 +31,13 @@ def trim_string_back(string, width):
class TablePrinter:
def __init__(
self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int]
):
def __init__(self, row_cls: type[DataclassInstance], column_widths: dict[str, int]):
self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames)
def print_table(self, rows: list[dataclasses.dataclass]):
def print_table(self, rows: list[DataclassInstance]):
self._print_header()
self._print_line()
for row in rows:

View File

@@ -98,7 +98,7 @@ class FullAttentionSpec(AttentionSpec):
In this case, we use FullAttentionSpec and record the sliding window size.
"""
head_size_v: int | None = None
head_size_v: int = None # type: ignore[assignment]
sliding_window: int | None = None
"""