[Mypy] Fix mypy for vllm/model_executor (except vllm/model_executor/layers) (#37904)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2026-03-24 17:14:01 +00:00
Committed by: GitHub
Parent: dc78c2c933
Commit: b3601da6e7

10 changed files with 44 additions and 39 deletions

View File

@@ -96,6 +96,7 @@ def sparse_attn_indexer(
     topk_indices_buffer[: hidden_states.shape[0]] = -1
     if has_prefill:
         prefill_metadata = attn_metadata.prefill
+        assert prefill_metadata is not None
         # Get the full shared workspace buffers once (will allocate on first use)
         workspace_manager = current_workspace_manager()
@@ -170,6 +171,8 @@ def sparse_attn_indexer(
     if has_decode:
         decode_metadata = attn_metadata.decode
+        assert decode_metadata is not None
         # kv_cache size requirement [num_block, block_size, n_head, head_dim],
         # we only have [num_block, block_size, head_dim],
         kv_cache = kv_cache.unsqueeze(-2)
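Both hunks apply the same fix: `attn_metadata.prefill` and `attn_metadata.decode` are Optional fields, and mypy will not allow attribute access on them until the value has been narrowed to a non-None type. A minimal, self-contained sketch of the pattern (hypothetical class names, not vLLM code):

```python
from dataclasses import dataclass


@dataclass
class PrefillMetadata:  # hypothetical stand-in for the real metadata class
    num_tokens: int


@dataclass
class AttnMetadata:
    prefill: PrefillMetadata | None  # None when the batch has no prefill


def indexer(attn_metadata: AttnMetadata, has_prefill: bool) -> None:
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        # mypy sees `PrefillMetadata | None` here; the assert narrows the
        # local to `PrefillMetadata`, so the attribute access type-checks.
        assert prefill_metadata is not None
        print(prefill_metadata.num_tokens)
```

At runtime the assert also documents the invariant that `has_prefill` implies the prefill metadata is populated.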

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import os
 from collections.abc import Generator
+from typing import TYPE_CHECKING, cast

 import gguf
 import regex as re
@@ -27,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
 from vllm.utils.torch_utils import set_default_torch_dtype
+
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig

 logger = init_logger(__name__)
@@ -350,10 +354,9 @@ class GGUFModelLoader(BaseModelLoader):
             for name, weight_type in weight_type_map.items()
             if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
         ]
-        logger.debug(
-            "GGUF unquantized modules: %s",
-            unquant_names,
-        )
+        logger.debug("GGUF unquantized modules: %s", unquant_names)
+        if TYPE_CHECKING:
+            vllm_config.quant_config = cast(GGUFConfig, vllm_config.quant_config)
         vllm_config.quant_config.unquantized_modules.extend(unquant_names)

         target_device = torch.device(device_config.device)
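The `if TYPE_CHECKING:` block is evaluated only by the type checker, so the `cast` re-types `vllm_config.quant_config` as `GGUFConfig` for the lines that follow without adding any runtime work. A sketch of the idea, using stand-in classes rather than the real vLLM config types:

```python
from typing import TYPE_CHECKING, cast


class QuantizationConfig:  # stand-in base class
    pass


class GGUFConfig(QuantizationConfig):  # stand-in subclass
    def __init__(self) -> None:
        self.unquantized_modules: list[str] = []


class VllmConfig:  # stand-in config holder
    quant_config: QuantizationConfig | None = None


def record_unquantized(vllm_config: VllmConfig, names: list[str]) -> None:
    if TYPE_CHECKING:
        # Never executed at runtime; for mypy, this assignment narrows the
        # attribute's type from `QuantizationConfig | None` to `GGUFConfig`.
        vllm_config.quant_config = cast(GGUFConfig, vllm_config.quant_config)
    vllm_config.quant_config.unquantized_modules.extend(names)
```

In the loader this is safe because a GGUF model always has a `GGUFConfig` attached by this point; the cast only tells mypy what the runtime already guarantees.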

View File

@@ -27,28 +27,16 @@ class RunaiModelStreamerLoader(BaseModelLoader):
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        self._is_distributed = False
+        self._is_distributed: bool = False
         if load_config.model_loader_extra_config:
             extra_config = load_config.model_loader_extra_config
-            if "distributed" in extra_config and isinstance(
-                extra_config.get("distributed"), bool
-            ):
-                self._is_distributed = extra_config.get("distributed")
-            if "concurrency" in extra_config and isinstance(
-                extra_config.get("concurrency"), int
-            ):
-                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
-                    extra_config.get("concurrency")
-                )
-            if "memory_limit" in extra_config and isinstance(
-                extra_config.get("memory_limit"), int
-            ):
-                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
-                    extra_config.get("memory_limit")
-                )
+            if isinstance(distributed := extra_config.get("distributed"), bool):
+                self._is_distributed = distributed
+            if isinstance(concurrency := extra_config.get("concurrency"), int):
+                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(concurrency)
+            if isinstance(memory_limit := extra_config.get("memory_limit"), int):
+                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(memory_limit)

         runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
         aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
@@ -93,7 +81,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
         return hf_weights_files

     def _get_weights_iterator(
-        self, model_or_path: str, revision: str
+        self, model_or_path: str, revision: str | None
     ) -> Generator[tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_weights_files = self._prepare_weights(model_or_path, revision)
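Two independent fixes in this file. The walrus rewrite in `__init__` binds each looked-up value once, and `isinstance()` then narrows it, so mypy knows `distributed` is a `bool` (and so on) inside each branch; the old `"key" in extra_config` checks narrowed nothing, and each `.get()` call could still return None as far as the checker is concerned. The signature change simply acknowledges that callers pass `revision=None`. A sketch of the narrowing pattern (hypothetical function, not vLLM code):

```python
import os


def apply_extra_config(extra_config: dict[str, object]) -> bool:
    is_distributed = False
    # `:=` binds the result of a single dict lookup; isinstance() narrows it,
    # so inside each branch mypy sees a concrete bool/int rather than
    # `object | None`.
    if isinstance(distributed := extra_config.get("distributed"), bool):
        is_distributed = distributed
    if isinstance(concurrency := extra_config.get("concurrency"), int):
        os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(concurrency)
    return is_distributed
```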

View File

@@ -6,6 +6,7 @@ import glob
 import os
 import time
 from collections.abc import Generator
+from copy import copy
 from typing import Any

 import torch
@@ -42,7 +43,7 @@ class ShardedStateLoader(BaseModelLoader):
         extra_config = (
             {}
             if load_config.model_loader_extra_config is None
-            else load_config.model_loader_extra_config.copy()
+            else copy(load_config.model_loader_extra_config)
         )
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
         if extra_config:
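`load_config.model_loader_extra_config` is not typed as a plain dict (it is a union in `LoadConfig`), so mypy cannot assume a `.copy()` method exists on every member; the module-level `copy()` function is typed `(T) -> T` and accepts any object. A sketch under that assumption:

```python
from copy import copy
from dataclasses import dataclass


@dataclass
class ExtraConfig:  # hypothetical non-dict member of the union; no .copy()
    pattern: str = "*.safetensors"


def snapshot(extra_config: dict[str, str] | ExtraConfig | None) -> object:
    # copy() works uniformly across the union, whereas mypy rejects
    # extra_config.copy() because ExtraConfig defines no such method.
    return {} if extra_config is None else copy(extra_config)
```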

View File

@@ -674,7 +674,8 @@ def serialize_vllm_model(
key = f.read()
encryption_params = EncryptionParams(key=key)
output_file = tensorizer_args.tensorizer_uri
if (output_file := tensorizer_args.tensorizer_uri) is None:
raise ValueError("tensorizer_uri must be specified for serialization.")
if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank
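The old line merely aliased an Optional value, leaving `output_file` as `str | None` everywhere below; combining a walrus assignment with an early `raise` narrows it to `str` and turns a would-be crash on `None` later into an immediate, descriptive error. Sketch with a stand-in args object:

```python
from dataclasses import dataclass


@dataclass
class TensorizerArgs:  # stand-in for the real args object
    tensorizer_uri: str | None = None


def resolve_output_file(tensorizer_args: TensorizerArgs) -> str:
    # After the raise, mypy (and the reader) know output_file is a str.
    if (output_file := tensorizer_args.tensorizer_uri) is None:
        raise ValueError("tensorizer_uri must be specified for serialization.")
    return output_file
```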

View File

@@ -121,6 +121,7 @@ class TensorizerLoader(BaseModelLoader):
         if parallel_config.tensor_parallel_size > 1:
             from vllm.distributed import get_tensor_model_parallel_rank

+            assert self.tensorizer_config.tensorizer_uri is not None
             self.tensorizer_config.tensorizer_uri = (
                 self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
             )

View File

@@ -6,6 +6,7 @@ import inspect
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from typing import Any

 import torch
 from torch import nn
@@ -71,7 +72,7 @@ def initialize_model(
         model_class,
     )

     # try to be compatible with old-style model class
-    kwargs = {}
+    kwargs: dict[str, Any] = {}
     if "prefix" in all_params:
         kwargs["prefix"] = prefix
     if "config" in all_params:

View File

@@ -257,6 +257,8 @@ def convert_bin_to_safetensor_file(
 def get_quant_config(
     model_config: ModelConfig, load_config: LoadConfig
 ) -> QuantizationConfig:
+    if model_config.quantization is None:
+        raise ValueError("Model quantization method is not specified in the config.")
     quant_cls = get_quantization_config(model_config.quantization)

     # GGUF doesn't have config file
@@ -307,6 +309,11 @@ def get_quant_config(
     # if hf_quant_config is None, we will try to get config from
     # hf_overrides
     hf_overrides = model_config.hf_overrides
+    if not isinstance(hf_overrides, dict):
+        raise ValueError(
+            "hf_overrides must be a dict for get_quant_config "
+            "to get the quantization config from it."
+        )
     quantization_config_file = hf_overrides.get("quantization_config_file", None)
     if quantization_config_file is not None:
         if hasattr(quant_cls, "from_config_file"):
@@ -1087,7 +1094,7 @@ def multi_thread_pt_weights_iterator(
 def get_gguf_extra_tensor_names(
-    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
+    gguf_file: str | Path, gguf_to_hf_name_map: dict[str, str]
 ) -> list[str]:
     reader = gguf.GGUFReader(gguf_file)
     expected_gguf_keys = set(gguf_to_hf_name_map.keys())
@@ -1097,7 +1104,7 @@ def get_gguf_extra_tensor_names(
def get_gguf_weight_type_map(
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
gguf_file: str | Path, gguf_to_hf_name_map: dict[str, str]
) -> dict[str, str]:
"""
Return GGUF mapped weight's name and its quant type
@@ -1111,7 +1118,7 @@ def get_gguf_weight_type_map(
def gguf_quant_weights_iterator(
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
gguf_file: str | Path, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""
Iterate over the quant weights in the model gguf files and convert
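All three GGUF helpers get the same widening: callers hand these functions `pathlib.Path` objects, and `gguf.GGUFReader` accepts path-like input, so `str | Path` describes what the code already did. The shape of the change:

```python
from pathlib import Path


def read_gguf_magic(gguf_file: str | Path) -> bytes:
    # str | Path matches what open() (and path-accepting readers generally)
    # take, so callers no longer need str() conversions to satisfy mypy.
    with open(gguf_file, "rb") as f:
        return f.read(4)
```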

View File

@@ -154,8 +154,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
self.data.copy_(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
shard_offset: int = kwargs["shard_offset"]
shard_size: int = kwargs["shard_size"]
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if (
@@ -176,10 +176,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads")
shard_offset: int = kwargs["shard_offset"]
shard_size: int = kwargs["shard_size"]
shard_id: str = kwargs["shard_id"]
num_heads: int = kwargs["num_heads"]
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if (
@@ -191,10 +191,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
)
param_data = self.data
shard_id = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
shard_id_int = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
self.output_dim, shard_id_int * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
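Two related fixes here. First, `kwargs.get("...")` types as possibly-None, which poisons the arithmetic below; since these keys are required, indexing with an explicit annotation is both stricter at runtime (KeyError instead of a silent None) and precise for mypy. Second, the original code rebound `shard_id` from `str` to `int`, which mypy flags as an incompatible redefinition; the new `shard_id_int` name keeps each variable a single type. A compressed sketch (standalone function, not the real method):

```python
from typing import Any


def load_qkv_weight(tp_rank: int, **kwargs: Any) -> int:
    # Required keys: indexing raises KeyError if absent, and the annotations
    # give mypy concrete types where .get() would admit None.
    shard_size: int = kwargs["shard_size"]
    shard_id: str = kwargs["shard_id"]
    num_heads: int = kwargs["num_heads"]
    # A fresh name avoids redefining shard_id (str) as an int, which mypy
    # rejects; it also reads more honestly.
    shard_id_int = tp_rank if shard_id == "q" else tp_rank // num_heads
    return shard_id_int * shard_size
```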