Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions

View File

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
import torch.nn.functional as F
@@ -24,7 +24,7 @@ class TorchAOConfig(QuantizationConfig):
def get_name(self) -> QuantizationMethods:
return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
@@ -32,11 +32,11 @@ class TorchAOConfig(QuantizationConfig):
return 75
@staticmethod
def get_config_filenames() -> List[str]:
def get_config_filenames() -> list[str]:
return ["config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
"""Create the quant config from an hf model config"""
try:
from torchao.core.config import config_from_dict
@@ -60,7 +60,7 @@ class TorchAOConfig(QuantizationConfig):
return TorchAOLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
def get_scaled_act_names(self) -> list[str]:
return []
@@ -97,7 +97,7 @@ class TorchAOLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,