Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions

View File

@@ -2,7 +2,7 @@
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch import nn
@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):
def method_has_implemented_embedding(
method_class: Type[QuantizeMethodBase]) -> bool:
method_class: type[QuantizeMethodBase]) -> bool:
"""
Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function
@@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
self.packed_modules_mapping: dict[str, list[str]] = dict()
@abstractmethod
def get_name(self) -> QuantizationMethods:
@@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
raise NotImplementedError
@abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@@ -93,13 +93,13 @@ class QuantizationConfig(ABC):
@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
return None
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
@@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
"quantization config.")
@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
def get_from_keys_or(config: dict[str, Any], keys: list[str],
default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try: