Update deprecated type hinting in model_loader (#18130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -124,7 +124,7 @@ def device_loading_context(module: torch.nn.Module,
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: Dict[str, torch.device] = {}
|
||||
original_device_states: dict[str, torch.device] = {}
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
@@ -214,7 +214,7 @@ def resolve_transformers_arch(model_config: ModelConfig,
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
|
||||
# Special handling for quantized Mixtral.
|
||||
@@ -257,8 +257,8 @@ class ParamMapping:
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
packed_mapping: Dict[str, List[str]]
|
||||
inverse_packed_mapping: Dict[str, Tuple[str,
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str,
|
||||
int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -273,7 +273,7 @@ class ParamMapping:
|
||||
)
|
||||
|
||||
def get_sub_modules(self,
|
||||
module_name: str) -> Optional[Tuple[str, List[str]]]:
|
||||
module_name: str) -> Optional[tuple[str, list[str]]]:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
@@ -281,7 +281,7 @@ class ParamMapping:
|
||||
|
||||
|
||||
def configure_quant_config(quant_config: QuantizationConfig,
|
||||
model_class: Type[nn.Module]):
|
||||
model_class: type[nn.Module]):
|
||||
"""
|
||||
Pass packed_modules_mapping by reference to quant_config so that
|
||||
quant_config can properly match fused modules
|
||||
|
||||
Reference in New Issue
Block a user