diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 101a4833f..539261cf2 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -23,39 +23,8 @@ import sys import regex as re -FILES = [ - "vllm/*.py", - "vllm/assets", - "vllm/compilation", - "vllm/distributed", - "vllm/engine", - "vllm/entrypoints", - "vllm/executor", - "vllm/inputs", - "vllm/logging_utils", - "vllm/multimodal", - "vllm/platforms", - "vllm/plugins", - "vllm/renderers", - "vllm/tokenizers", - "vllm/transformers_utils", - "vllm/triton_utils", - "vllm/usage", - "vllm/utils", - "vllm/worker", - "vllm/v1/attention", - "vllm/v1/core", - "vllm/v1/engine", - "vllm/v1/executor", - "vllm/v1/metrics", - "vllm/v1/pool", - "vllm/v1/sample", - "vllm/v1/structured_output", - "vllm/v1/worker", -] - # After fixing errors resulting from changing follow_imports -# from "skip" to "silent", move the following directories to FILES +# from "skip" to "silent", remove its directory from SEPARATE_GROUPS. SEPARATE_GROUPS = [ "tests", # v0 related @@ -74,6 +43,16 @@ EXCLUDE = [ "vllm/model_executor/layers/fla/ops", # Ignore triton kernels in ops. "vllm/v1/attention/ops", + # TODO: Remove these entries after fixing mypy errors. + "vllm/benchmarks", + "vllm/config", + "vllm/device_allocator", + "vllm/profiler", + "vllm/reasoning", + "vllm/tool_parser", + "vllm/v1/cudagraph_dispatcher.py", + "vllm/outputs.py", + "vllm/logger.py", ] @@ -88,7 +67,6 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: A dictionary mapping file group names to lists of changed files. 
""" exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*") - files_pattern = re.compile(f"^({'|'.join(FILES)}).*") file_groups = {"": []} file_groups.update({k: [] for k in SEPARATE_GROUPS}) for changed_file in changed_files: @@ -96,14 +74,13 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: if exclude_pattern.match(changed_file): continue # Group files by mypy call - if files_pattern.match(changed_file): - file_groups[""].append(changed_file) - continue + for directory in SEPARATE_GROUPS: + if re.match(f"^{directory}.*", changed_file): + file_groups[directory].append(changed_file) + break else: - for directory in SEPARATE_GROUPS: - if re.match(f"^{directory}.*", changed_file): - file_groups[directory].append(changed_file) - break + if changed_file.startswith("vllm/"): + file_groups[""].append(changed_file) return file_groups diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 803ad8e93..1888d78ca 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -349,7 +349,7 @@ def _rocm_aiter_mla_decode_fwd_impl( ) -> None: from aiter.mla import mla_decode_fwd - kwargs = { + kwargs: dict[str, float | torch.Tensor | None] = { "sm_scale": sm_scale, "logit_cap": logit_cap, } diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index b6a172db2..d3e37f589 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -570,7 +570,7 @@ class CompilationConfig: pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" - max_cudagraph_capture_size: int | None = field(default=None) + max_cudagraph_capture_size: int = field(default=None) """The maximum cudagraph capture size. 
If cudagraph_capture_sizes is specified, this will be set to the largest @@ -743,6 +743,7 @@ class CompilationConfig: "level", "mode", "cudagraph_mode", + "max_cudagraph_capture_size", "use_inductor_graph_partition", mode="wrap", ) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 614373782..9288948c5 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -9,7 +9,7 @@ import inspect import json import pathlib import textwrap -from collections.abc import Callable, Iterable, Mapping, Sequence, Set +from collections.abc import Callable, Mapping, Sequence, Set from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -75,7 +75,7 @@ def get_field(cls: ConfigType, name: str) -> Field: def getattr_iter( object: object, - names: Iterable[str], + names: Sequence[str], default: Any | None = None, default_factory: Callable[[], Any] | None = None, warn: bool = False, diff --git a/vllm/env_override.py b/vllm/env_override.py index 8b2f2424d..e5a40dc3c 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -382,7 +382,7 @@ def _patch_get_raw_stream_if_needed(): if hasattr(torch._C, "_cuda_getCurrentRawStream"): from torch._C import _cuda_getCurrentRawStream as _get_raw_stream - builtins.get_raw_stream = _get_raw_stream + builtins.get_raw_stream = _get_raw_stream # type: ignore[attr-defined] _patch_get_raw_stream_if_needed() diff --git a/vllm/envs.py b/vllm/envs.py index ad220a979..1c9eacae1 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1680,6 +1680,7 @@ def disable_envs_cache() -> None: global __getattr__ # If __getattr__ is wrapped by functions.cache, unwrap the caching layer. 
if _is_envs_cache_enabled(): + assert hasattr(__getattr__, "__wrapped__") __getattr__ = __getattr__.__wrapped__ diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 731c45fbb..20af24c2c 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -270,7 +270,7 @@ def create_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, - slot_mapping: dict[str, torch.Tensor] | None = None, + slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, additional_kwargs: dict[str, Any] | None = None, skip_compiled: bool = False, ): diff --git a/vllm/logger.py b/vllm/logger.py index e6e380794..2ec20003b 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -157,7 +157,7 @@ _METHODS_TO_PATCH = { def _configure_vllm_root_logger() -> None: - logging_config = dict[str, Any]() + logging_config = dict[str, dict[str, Any] | Any]() if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH: raise RuntimeError( diff --git a/vllm/logprobs.py b/vllm/logprobs.py index 6a820308f..cc77f5f7f 100644 --- a/vllm/logprobs.py +++ b/vllm/logprobs.py @@ -28,7 +28,7 @@ LogprobsOnePosition = dict[int, Logprob] @dataclass -class FlatLogprobs(MutableSequence[LogprobsOnePosition]): +class FlatLogprobs(MutableSequence[LogprobsOnePosition | None]): """ Flat logprobs of a request into multiple primitive type lists. 
@@ -140,7 +140,7 @@ class FlatLogprobs(MutableSequence[LogprobsOnePosition]): def __delitem__(self, item) -> None: raise TypeError("Cannot delete logprobs from FlatLogprobs") - def insert(self, item) -> None: + def insert(self, index: int, value: dict[int, Logprob] | None) -> None: raise TypeError("Cannot insert logprobs to FlatLogprobs") def __iter__(self) -> Iterator[LogprobsOnePosition]: @@ -161,7 +161,7 @@ SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition] def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs: """Creates a container to store prompt logprobs for a request""" - logprobs = FlatLogprobs() if flat_logprobs else [] + logprobs: PromptLogprobs = FlatLogprobs() if flat_logprobs else [] # NOTE: logprob of first prompt token is None. logprobs.append(None) return logprobs diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py index c95f9f4ac..1ef229f27 100644 --- a/vllm/profiler/utils.py +++ b/vllm/profiler/utils.py @@ -4,6 +4,12 @@ +from __future__ import annotations + import dataclasses from collections.abc import Callable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from _typeshed import DataclassInstance from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata # @@ -11,7 +17,7 @@ from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata # -def trim_string_front(string, width): +def trim_string_front(string: str, width: int) -> str: if len(string) > width: offset = len(string) - width + 3 string = string[offset:] @@ -20,7 +26,7 @@ def trim_string_front(string, width): return string -def trim_string_back(string, width): +def trim_string_back(string: str, width: int) -> str: if len(string) > width: offset = len(string) - width + 3 string = string[:-offset] @@ -30,15 +36,13 @@ def trim_string_back(string, width): class TablePrinter: - def __init__( - self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int] - ): + def __init__(self, row_cls: type[DataclassInstance], column_widths: dict[str, int]): self.row_cls = row_cls self.fieldnames = [x.name for x in 
dataclasses.fields(row_cls)] self.column_widths = column_widths assert set(self.column_widths.keys()) == set(self.fieldnames) - def print_table(self, rows: list[dataclasses.dataclass]): + def print_table(self, rows: list[DataclassInstance]): self._print_header() self._print_line() for row in rows: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 27c6f7da2..4a1b16fc5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -98,7 +98,7 @@ class FullAttentionSpec(AttentionSpec): In this case, we use FullAttentionSpec and record the sliding window size. """ - head_size_v: int | None = None + head_size_v: int = None # type: ignore[assignment] sliding_window: int | None = None """