Make mypy opt-out instead of opt-in (#33205)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,39 +23,8 @@ import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
FILES = [
|
||||
"vllm/*.py",
|
||||
"vllm/assets",
|
||||
"vllm/compilation",
|
||||
"vllm/distributed",
|
||||
"vllm/engine",
|
||||
"vllm/entrypoints",
|
||||
"vllm/executor",
|
||||
"vllm/inputs",
|
||||
"vllm/logging_utils",
|
||||
"vllm/multimodal",
|
||||
"vllm/platforms",
|
||||
"vllm/plugins",
|
||||
"vllm/renderers",
|
||||
"vllm/tokenizers",
|
||||
"vllm/transformers_utils",
|
||||
"vllm/triton_utils",
|
||||
"vllm/usage",
|
||||
"vllm/utils",
|
||||
"vllm/worker",
|
||||
"vllm/v1/attention",
|
||||
"vllm/v1/core",
|
||||
"vllm/v1/engine",
|
||||
"vllm/v1/executor",
|
||||
"vllm/v1/metrics",
|
||||
"vllm/v1/pool",
|
||||
"vllm/v1/sample",
|
||||
"vllm/v1/structured_output",
|
||||
"vllm/v1/worker",
|
||||
]
|
||||
|
||||
# After fixing errors resulting from changing follow_imports
|
||||
# from "skip" to "silent", move the following directories to FILES
|
||||
# from "skip" to "silent", remove its directory from SEPARATE_GROUPS.
|
||||
SEPARATE_GROUPS = [
|
||||
"tests",
|
||||
# v0 related
|
||||
@@ -74,6 +43,16 @@ EXCLUDE = [
|
||||
"vllm/model_executor/layers/fla/ops",
|
||||
# Ignore triton kernels in ops.
|
||||
"vllm/v1/attention/ops",
|
||||
# TODO: Remove these entries after fixing mypy errors.
|
||||
"vllm/benchmarks",
|
||||
"vllm/config",
|
||||
"vllm/device_allocator",
|
||||
"vllm/profiler",
|
||||
"vllm/reasoning",
|
||||
"vllm/tool_parser",
|
||||
"vllm/v1/cudagraph_dispatcher.py",
|
||||
"vllm/outputs.py",
|
||||
"vllm/logger.py",
|
||||
]
|
||||
|
||||
|
||||
@@ -88,7 +67,6 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
|
||||
A dictionary mapping file group names to lists of changed files.
|
||||
"""
|
||||
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
|
||||
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
|
||||
file_groups = {"": []}
|
||||
file_groups.update({k: [] for k in SEPARATE_GROUPS})
|
||||
for changed_file in changed_files:
|
||||
@@ -96,14 +74,13 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
|
||||
if exclude_pattern.match(changed_file):
|
||||
continue
|
||||
# Group files by mypy call
|
||||
if files_pattern.match(changed_file):
|
||||
file_groups[""].append(changed_file)
|
||||
continue
|
||||
for directory in SEPARATE_GROUPS:
|
||||
if re.match(f"^{directory}.*", changed_file):
|
||||
file_groups[directory].append(changed_file)
|
||||
break
|
||||
else:
|
||||
for directory in SEPARATE_GROUPS:
|
||||
if re.match(f"^{directory}.*", changed_file):
|
||||
file_groups[directory].append(changed_file)
|
||||
break
|
||||
if changed_file.startswith("vllm/"):
|
||||
file_groups[""].append(changed_file)
|
||||
return file_groups
|
||||
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
kwargs = {
|
||||
kwargs: dict[str, float | torch.Tensor | None] = {
|
||||
"sm_scale": sm_scale,
|
||||
"logit_cap": logit_cap,
|
||||
}
|
||||
|
||||
@@ -570,7 +570,7 @@ class CompilationConfig:
|
||||
pass_config: PassConfig = field(default_factory=PassConfig)
|
||||
"""Custom inductor passes, see PassConfig for more details"""
|
||||
|
||||
max_cudagraph_capture_size: int | None = field(default=None)
|
||||
max_cudagraph_capture_size: int = field(default=None)
|
||||
"""The maximum cudagraph capture size.
|
||||
|
||||
If cudagraph_capture_sizes is specified, this will be set to the largest
|
||||
@@ -743,6 +743,7 @@ class CompilationConfig:
|
||||
"level",
|
||||
"mode",
|
||||
"cudagraph_mode",
|
||||
"max_cudagraph_capture_size",
|
||||
"use_inductor_graph_partition",
|
||||
mode="wrap",
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ import inspect
|
||||
import json
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
|
||||
from collections.abc import Callable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
@@ -75,7 +75,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
|
||||
def getattr_iter(
|
||||
object: object,
|
||||
names: Iterable[str],
|
||||
names: Sequence[str],
|
||||
default: Any | None = None,
|
||||
default_factory: Callable[[], Any] | None = None,
|
||||
warn: bool = False,
|
||||
|
||||
@@ -382,7 +382,7 @@ def _patch_get_raw_stream_if_needed():
|
||||
if hasattr(torch._C, "_cuda_getCurrentRawStream"):
|
||||
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
|
||||
|
||||
builtins.get_raw_stream = _get_raw_stream
|
||||
builtins.get_raw_stream = _get_raw_stream # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_patch_get_raw_stream_if_needed()
|
||||
|
||||
@@ -1680,6 +1680,7 @@ def disable_envs_cache() -> None:
|
||||
global __getattr__
|
||||
# If __getattr__ is wrapped by functools.cache, unwrap the caching layer.
|
||||
if _is_envs_cache_enabled():
|
||||
assert hasattr(__getattr__, "__wrapped__")
|
||||
__getattr__ = __getattr__.__wrapped__
|
||||
|
||||
|
||||
|
||||
@@ -270,7 +270,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
slot_mapping: dict[str, torch.Tensor] | None = None,
|
||||
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
|
||||
additional_kwargs: dict[str, Any] | None = None,
|
||||
skip_compiled: bool = False,
|
||||
):
|
||||
|
||||
@@ -157,7 +157,7 @@ _METHODS_TO_PATCH = {
|
||||
|
||||
|
||||
def _configure_vllm_root_logger() -> None:
|
||||
logging_config = dict[str, Any]()
|
||||
logging_config = dict[str, dict[str, Any] | Any]()
|
||||
|
||||
if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -28,7 +28,7 @@ LogprobsOnePosition = dict[int, Logprob]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
class FlatLogprobs(MutableSequence[LogprobsOnePosition | None]):
|
||||
"""
|
||||
Flat logprobs of a request into multiple primitive type lists.
|
||||
|
||||
@@ -140,7 +140,7 @@ class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
def __delitem__(self, item) -> None:
|
||||
raise TypeError("Cannot delete logprobs from FlatLogprobs")
|
||||
|
||||
def insert(self, item) -> None:
|
||||
def insert(self, index: int, value: dict[int, Logprob] | None) -> None:
|
||||
raise TypeError("Cannot insert logprobs to FlatLogprobs")
|
||||
|
||||
def __iter__(self) -> Iterator[LogprobsOnePosition]:
|
||||
@@ -161,7 +161,7 @@ SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
|
||||
|
||||
def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
|
||||
"""Creates a container to store prompt logprobs for a request"""
|
||||
logprobs = FlatLogprobs() if flat_logprobs else []
|
||||
logprobs: PromptLogprobs = FlatLogprobs() if flat_logprobs else []
|
||||
# NOTE: logprob of first prompt token is None.
|
||||
logprobs.append(None)
|
||||
return logprobs
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import dataclasses
|
||||
from collections.abc import Callable
|
||||
|
||||
from _typeshed import DataclassInstance
|
||||
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
|
||||
|
||||
#
|
||||
@@ -11,7 +12,7 @@ from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
|
||||
#
|
||||
|
||||
|
||||
def trim_string_front(string, width):
|
||||
def trim_string_front(string: str, width: int) -> str:
|
||||
if len(string) > width:
|
||||
offset = len(string) - width + 3
|
||||
string = string[offset:]
|
||||
@@ -20,7 +21,7 @@ def trim_string_front(string, width):
|
||||
return string
|
||||
|
||||
|
||||
def trim_string_back(string, width):
|
||||
def trim_string_back(string: str, width: int) -> str:
|
||||
if len(string) > width:
|
||||
offset = len(string) - width + 3
|
||||
string = string[:-offset]
|
||||
@@ -30,15 +31,13 @@ def trim_string_back(string, width):
|
||||
|
||||
|
||||
class TablePrinter:
|
||||
def __init__(
|
||||
self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int]
|
||||
):
|
||||
def __init__(self, row_cls: type[DataclassInstance], column_widths: dict[str, int]):
|
||||
self.row_cls = row_cls
|
||||
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
|
||||
self.column_widths = column_widths
|
||||
assert set(self.column_widths.keys()) == set(self.fieldnames)
|
||||
|
||||
def print_table(self, rows: list[dataclasses.dataclass]):
|
||||
def print_table(self, rows: list[DataclassInstance]):
|
||||
self._print_header()
|
||||
self._print_line()
|
||||
for row in rows:
|
||||
|
||||
@@ -98,7 +98,7 @@ class FullAttentionSpec(AttentionSpec):
|
||||
In this case, we use FullAttentionSpec and record the sliding window size.
|
||||
"""
|
||||
|
||||
head_size_v: int | None = None
|
||||
head_size_v: int = None # type: ignore[assignment]
|
||||
|
||||
sliding_window: int | None = None
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user