Update deprecated type hinting in model_loader (#18130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-05-15 12:00:21 +01:00; committed by GitHub
parent a9944aabfa
commit 07ad27121f
12 changed files with 80 additions and 74 deletions
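Note: the change replaces the typing aliases Dict, List, Tuple and Type, which have been deprecated since Python 3.9, with the built-in generics standardized by PEP 585. A minimal before/after sketch (the function name "lookup" is made up for illustration and does not appear in the diff):

from typing import Optional

# Before (deprecated since Python 3.9):
# from typing import Dict, List, Optional, Tuple, Type
# def lookup(table: Dict[str, List[int]]) -> Optional[Tuple[str, Type[int]]]: ...

# After (PEP 585 built-in generics; Optional still comes from typing):
def lookup(table: dict[str, list[int]]) -> Optional[tuple[str, type[int]]]:
    ...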


@@ -5,7 +5,7 @@ import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Optional
import torch
import transformers
@@ -124,7 +124,7 @@ def device_loading_context(module: torch.nn.Module,
yield module
return
-original_device_states: Dict[str, torch.device] = {}
+original_device_states: dict[str, torch.device] = {}
# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
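For readers unfamiliar with device_loading_context, here is a hypothetical, self-contained sketch of the pattern the surrounding code appears to follow (remember where CPU parameters lived, move them to the target device for loading, restore them on exit); the names and details are illustrative, not copied from vLLM:

import contextlib
import torch

@contextlib.contextmanager
def device_loading_sketch(module: torch.nn.Module, target_device: torch.device):
    # Remember the original placement of CPU parameters only.
    original_device_states: dict[str, torch.device] = {}
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
    try:
        yield module
    finally:
        # Move the parameters we touched back to where they started.
        for name, p in module.named_parameters():
            if name in original_device_states:
                p.data = p.data.to(original_device_states[name])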
@@ -214,7 +214,7 @@ def resolve_transformers_arch(model_config: ModelConfig,
def get_model_architecture(
-model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
+model_config: ModelConfig) -> tuple[type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
@@ -257,8 +257,8 @@ class ParamMapping:
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
-packed_mapping: Dict[str, List[str]]
-inverse_packed_mapping: Dict[str, Tuple[str,
-                                        int]] = field(default_factory=dict)
+packed_mapping: dict[str, list[str]]
+inverse_packed_mapping: dict[str, tuple[str,
+                                        int]] = field(default_factory=dict)
def __post_init__(self):
@@ -273,7 +273,7 @@ class ParamMapping:
)
def get_sub_modules(self,
-module_name: str) -> Optional[Tuple[str, List[str]]]:
+module_name: str) -> Optional[tuple[str, list[str]]]:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
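To make the ParamMapping hunks above concrete, here is a hypothetical, self-contained re-implementation of the idea (the class name PackedParamMapping and the qkv_proj example are illustrative, not taken from the diff): a packed parameter maps to its constituent parts, and, reading the hunks, __post_init__ plausibly builds the reverse lookup so each part can be traced back to its packed name and position.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PackedParamMapping:
    packed_mapping: dict[str, list[str]]
    inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)

    def __post_init__(self):
        # Reverse lookup: constituent name -> (packed name, position within it).
        for packed_name, parts in self.packed_mapping.items():
            for index, part in enumerate(parts):
                self.inverse_packed_mapping[part] = (packed_name, index)

    def get_sub_modules(self, module_name: str) -> Optional[tuple[str, list[str]]]:
        # Match on the suffix so fully qualified module names resolve too.
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None

mapping = PackedParamMapping({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
print(mapping.inverse_packed_mapping["k_proj"])           # ('qkv_proj', 1)
print(mapping.get_sub_modules("layers.0.attn.qkv_proj"))  # ('qkv_proj', ['q_proj', 'k_proj', 'v_proj'])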
@@ -281,7 +281,7 @@ class ParamMapping:
def configure_quant_config(quant_config: QuantizationConfig,
-model_class: Type[nn.Module]):
+model_class: type[nn.Module]):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules