Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
@@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_skip_modules: Optional[list[str]] = None,
llm_int8_threshold: float = 6.0,
) -> None:
super().__init__()
@@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return "bitsandbytes"
@classmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
@@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig):
return 70
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return [
"adapter_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
def get_safe_value(config, keys, default_value=None):
try:
@@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return None
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
# Split the prefix into its dot-separated components
components = prefix.split('.')
@@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
from bitsandbytes.nn import Int8Params