Update deprecated type hinting in model_loader (#18130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

Author: Harry Mellor
Committed by: GitHub
Date: 2025-05-15 12:00:21 +01:00
Commit: 07ad27121f (parent a9944aabfa)
12 changed files with 80 additions and 74 deletions

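Every change in this commit follows the same pattern, standardized by PEP 585 in Python 3.9: the typing aliases List, Dict, and Tuple are deprecated in favor of the builtin generics list, dict, and tuple, and ABCs such as Generator are imported from collections.abc instead of typing. A minimal sketch of the before/after (iter_weights is a hypothetical helper used only for illustration, not vLLM code):

    # Old style, deprecated since Python 3.9 (PEP 585):
    #   from typing import Any, Dict, Generator, List, Optional, Tuple
    #   def iter_weights(paths: List[str], revision: Optional[str] = None
    #                    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]: ...

    # New style: ABCs come from collections.abc; builtins are subscripted directly.
    from collections.abc import Generator
    from typing import Any, Optional

    def iter_weights(paths: list[str],
                     revision: Optional[str] = None
                     ) -> Generator[tuple[str, dict[str, Any]], None, None]:
        for path in paths:
            yield path, {}  # hypothetical stand-in for (name, tensor) pairs

Note that Any, Callable, and Optional still live in typing, which is why the imports in the diff below are trimmed rather than removed.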

@@ -6,7 +6,8 @@ import glob
 import itertools
 import math
 import os
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from collections.abc import Generator
+from typing import Any, Callable, Optional

 import numpy as np
 import torch
@@ -49,21 +50,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         super().__init__(load_config)
         # Save the module names without sharding.
-        self.unsharded_weights_modules: List[str] = []
+        self.unsharded_weights_modules: list[str] = []
         # Save the module names that are sharded by column.
-        self.column_sharded_weights_modules: List[str] = []
+        self.column_sharded_weights_modules: list[str] = []
         # Store all module names (from transformers) that support
         # BNB quantization.
-        self.target_modules: List[str] = []
+        self.target_modules: list[str] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name

     def _get_weight_files(
         self,
         model_name_or_path: str,
-        allowed_patterns: List[str],
+        allowed_patterns: list[str],
         revision: Optional[str] = None,
-    ) -> Tuple[str, List[str], str]:
+    ) -> tuple[str, list[str], str]:
         """Retrieve weight files. Download the files if necessary.

         Return the weight files and the file pattern."""
@@ -95,7 +96,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             f"No model weights found in: `{model_name_or_path}`")

     def _prepare_weights(self, model_name_or_path: str,
-                         revision: Optional[str]) -> Tuple[List[str], bool]:
+                         revision: Optional[str]) -> tuple[list[str], bool]:
         """Prepare weight files for the model."""

         allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
@@ -155,7 +156,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         revision: Optional[str],
         pre_quant: bool,
         load_8bit: bool,
-    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
-                                                                     Any]]:
+    ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
+                                                                     Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
         as well as the quantization state dictionary."""
@@ -175,7 +176,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         hf_weights_files, use_safetensors = self._prepare_weights(
             model_name_or_path, revision)

-        quant_state_dict: Dict[str, Any] = {}
+        quant_state_dict: dict[str, Any] = {}

         if pre_quant:
             if load_8bit:
@@ -257,7 +258,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         # Closure to parse quant_state for each prequant weight
         def _parse_quant_state(param_name: str,
-                               temp_state_dict: Dict) -> QuantState:
+                               temp_state_dict: dict) -> QuantState:
             quant_state = {}
             for k in temp_state_dict:
                 if param_name + "." in k:
@@ -415,7 +416,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
-        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
+        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
         self._get_bnb_target_modules(model)
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
@@ -480,7 +481,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         torch.cuda.empty_cache()

         param_dict = dict(model.named_parameters())
-        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
         # after the checks are updated to run on a new version
         from vllm.model_executor.models.utils import is_pp_missing_parameter
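Sweeps like this need not be hand-edited: pyupgrade's --py39-plus mode and Ruff's pyupgrade-derived rules (UP006 for builtin generics, UP035 for deprecated imports) perform these same rewrites mechanically, which is one way a single commit can cover all 12 changed files consistently.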