Update deprecated type hinting in model_loader (#18130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,8 @@ import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -49,21 +50,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
super().__init__(load_config)
|
||||
|
||||
# Save the module names without sharding.
|
||||
self.unsharded_weights_modules: List[str] = []
|
||||
self.unsharded_weights_modules: list[str] = []
|
||||
# Save the module names that are sharded by column.
|
||||
self.column_sharded_weights_modules: List[str] = []
|
||||
self.column_sharded_weights_modules: list[str] = []
|
||||
# Store all module names (from transformers) that support
|
||||
# BNB quantization.
|
||||
self.target_modules: List[str] = []
|
||||
self.target_modules: list[str] = []
|
||||
# mapping weight names from transformers to vllm.
|
||||
self.weight_mapper: Callable = lambda name: name
|
||||
|
||||
def _get_weight_files(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
allowed_patterns: List[str],
|
||||
allowed_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> Tuple[str, List[str], str]:
|
||||
) -> tuple[str, list[str], str]:
|
||||
"""Retrieve weight files. Download the files if necessary.
|
||||
|
||||
Return the weight files and the file pattern."""
|
||||
@@ -95,7 +96,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
f"No model weights found in: `{model_name_or_path}`")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> Tuple[List[str], bool]:
|
||||
revision: Optional[str]) -> tuple[list[str], bool]:
|
||||
"""Prepare weight files for the model."""
|
||||
|
||||
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
|
||||
@@ -155,7 +156,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
revision: Optional[str],
|
||||
pre_quant: bool,
|
||||
load_8bit: bool,
|
||||
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
|
||||
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
|
||||
Any]]:
|
||||
"""Get an iterator to the model weights with bitsandbytes quantization,
|
||||
as well as the quantization state dictionary."""
|
||||
@@ -175,7 +176,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
model_name_or_path, revision)
|
||||
|
||||
quant_state_dict: Dict[str, Any] = {}
|
||||
quant_state_dict: dict[str, Any] = {}
|
||||
|
||||
if pre_quant:
|
||||
if load_8bit:
|
||||
@@ -257,7 +258,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
# Closure to parse quant_state for each prequant weight
|
||||
def _parse_quant_state(param_name: str,
|
||||
temp_state_dict: Dict) -> QuantState:
|
||||
temp_state_dict: dict) -> QuantState:
|
||||
quant_state = {}
|
||||
for k in temp_state_dict:
|
||||
if param_name + "." in k:
|
||||
@@ -415,7 +416,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
# Modules whose weights might be fused on disk;
|
||||
# we need their output_sizes to shard them correctly on the fly with TP
|
||||
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
|
||||
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
|
||||
self._get_bnb_target_modules(model)
|
||||
for name, module in model.named_modules():
|
||||
# Some modules like `ReplicatedLinear` should not have their weights
|
||||
@@ -480,7 +481,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
param_dict = dict(model.named_parameters())
|
||||
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
|
||||
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
|
||||
# TODO: Change this lazy import to a normal import
|
||||
# after the checks are updated to run on a new version
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
|
||||
Reference in New Issue
Block a user