Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):
|
||||
|
||||
|
||||
def method_has_implemented_embedding(
|
||||
method_class: Type[QuantizeMethodBase]) -> bool:
|
||||
method_class: type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
@@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# mapping is updated by models as they initialize
|
||||
self.packed_modules_mapping: Dict[str, List[str]] = dict()
|
||||
self.packed_modules_mapping: dict[str, list[str]] = dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
@@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -93,13 +93,13 @@ class QuantizationConfig(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
@@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
|
||||
"quantization config.")
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
|
||||
def get_from_keys_or(config: dict[str, Any], keys: list[str],
|
||||
default: Any) -> Any:
|
||||
"""Get a optional value from the model's quantization config."""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user