Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Literal, Type, get_args
|
||||
from typing import Literal, get_args
|
||||
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
@@ -76,7 +76,7 @@ def register_quantization_config(quantization: str):
|
||||
return _wrapper
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
if quantization not in QUANTIZATION_METHODS:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
|
||||
@@ -110,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
from .torchao import TorchAOConfig
|
||||
from .tpu_int8 import Int8TpuConfig
|
||||
|
||||
method_to_config: dict[str, Type[QuantizationConfig]] = {
|
||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# and https://arxiv.org/pdf/2401.06118.pdf
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -98,7 +98,7 @@ def generic_dequantize_gemm(
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
output_shape = input.shape[:-1] + (scales.shape[0], )
|
||||
@@ -136,7 +136,7 @@ def optimized_dequantize_gemm(
|
||||
codebooks: torch.
|
||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
@@ -191,7 +191,7 @@ class AQLMConfig(QuantizationConfig):
|
||||
return "aqlm"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -199,11 +199,11 @@ class AQLMConfig(QuantizationConfig):
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
|
||||
in_group_size = cls.get_from_keys(config, ["in_group_size"])
|
||||
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
|
||||
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
|
||||
@@ -230,7 +230,7 @@ class AQLMLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
del output_size # Unused.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,7 +25,7 @@ class AWQConfig(QuantizationConfig):
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
@@ -48,7 +48,7 @@ class AWQConfig(QuantizationConfig):
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "awq"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -57,7 +57,7 @@ class AWQConfig(QuantizationConfig):
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
|
||||
@@ -65,7 +65,7 @@ class AWQConfig(QuantizationConfig):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
@@ -82,7 +82,7 @@ class AWQConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
|
||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@@ -46,8 +46,8 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[List[str]],
|
||||
full_config: Dict[str, Any]) -> None:
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
@@ -79,7 +79,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
return "awq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -87,11 +87,11 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
@@ -150,7 +150,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
@@ -189,7 +189,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -48,7 +48,7 @@ class QuantizeMethodBase(ABC):
|
||||
|
||||
|
||||
def method_has_implemented_embedding(
|
||||
method_class: Type[QuantizeMethodBase]) -> bool:
|
||||
method_class: type[QuantizeMethodBase]) -> bool:
|
||||
"""
|
||||
Not all quant methods have embedding implemented, so we need to check that
|
||||
it exists for our given method. We check this by making sure the function
|
||||
@@ -68,7 +68,7 @@ class QuantizationConfig(ABC):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# mapping is updated by models as they initialize
|
||||
self.packed_modules_mapping: Dict[str, List[str]] = dict()
|
||||
self.packed_modules_mapping: dict[str, list[str]] = dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
@@ -76,7 +76,7 @@ class QuantizationConfig(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -93,13 +93,13 @@ class QuantizationConfig(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -115,7 +115,7 @@ class QuantizationConfig(ABC):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
@@ -124,7 +124,7 @@ class QuantizationConfig(ABC):
|
||||
"quantization config.")
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
|
||||
def get_from_keys_or(config: dict[str, Any], keys: list[str],
|
||||
default: Any) -> Any:
|
||||
"""Get a optional value from the model's quantization config."""
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -105,7 +105,7 @@ class BitBLASConfig(QuantizationConfig):
|
||||
return "bitblas"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -114,12 +114,12 @@ class BitBLASConfig(QuantizationConfig):
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any],
|
||||
keys: List[str],
|
||||
def get_from_keys(config: dict[str, Any],
|
||||
keys: list[str],
|
||||
default: Any = None) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
@@ -128,7 +128,7 @@ class BitBLASConfig(QuantizationConfig):
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"], -1)
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"], False)
|
||||
@@ -193,7 +193,7 @@ class BitBLASLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
@@ -329,7 +329,7 @@ class BitBLASLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -29,7 +29,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
bnb_4bit_use_double_quant: bool = False,
|
||||
llm_int8_enable_fp32_cpu_offload: bool = False,
|
||||
llm_int8_has_fp16_weight: bool = False,
|
||||
llm_int8_skip_modules: Optional[List[str]] = None,
|
||||
llm_int8_skip_modules: Optional[list[str]] = None,
|
||||
llm_int8_threshold: float = 6.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -61,7 +61,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
return "bitsandbytes"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -69,13 +69,13 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
return 70
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"adapter_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
|
||||
|
||||
def get_safe_value(config, keys, default_value=None):
|
||||
try:
|
||||
@@ -130,7 +130,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
|
||||
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
|
||||
# Split the prefix into its dot-separated components
|
||||
components = prefix.split('.')
|
||||
|
||||
@@ -169,7 +169,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from contextlib import suppress
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, cast
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.config import (CompressionFormat,
|
||||
@@ -38,20 +38,20 @@ logger = init_logger(__name__)
|
||||
__all__ = ["CompressedTensorsLinearMethod"]
|
||||
|
||||
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_scheme_map: Dict[str, Any],
|
||||
ignore: List[str],
|
||||
target_scheme_map: dict[str, Any],
|
||||
ignore: list[str],
|
||||
quant_format: str,
|
||||
sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: List[str],
|
||||
kv_cache_scheme: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: list[str],
|
||||
kv_cache_scheme: Optional[dict[str, Any]] = None,
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
@@ -66,7 +66,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -102,8 +102,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
ignore: List[str] = cast(List[str], config.get("ignore", []))
|
||||
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
ignore: list[str] = cast(list[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(
|
||||
config=config)
|
||||
@@ -121,8 +121,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def _parse_sparsity_config(
|
||||
cls, config: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
|
||||
cls, config: dict[str, Any]
|
||||
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A tuple with two elements
|
||||
@@ -135,7 +135,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
sparsity_config = SparsityCompressionConfig.model_validate(
|
||||
sparsity_config)
|
||||
sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
|
||||
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
|
||||
target: sparsity_config
|
||||
for target in sparsity_config.targets or list()
|
||||
}
|
||||
@@ -144,13 +144,13 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def _quantization_scheme_map_from_config(
|
||||
cls, config: Dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
|
||||
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A dictionary mapping target layer names to their corresponding
|
||||
quantization_args for weights and input activations
|
||||
"""
|
||||
target_scheme_map: Dict[str, Any] = dict()
|
||||
target_scheme_map: dict[str, Any] = dict()
|
||||
quant_format = cast(str, config.get("format"))
|
||||
|
||||
# The quant_config has multiple config_groups, each containing
|
||||
@@ -188,7 +188,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return target_scheme_map
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
@@ -565,7 +565,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
@@ -611,7 +611,7 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache scheme. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
@@ -31,7 +31,7 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
quantized: bool = False,
|
||||
weight_quant: Optional[QuantizationArgs] = None,
|
||||
input_quant: Optional[QuantizationArgs] = None,
|
||||
model_compression_config: Optional[Dict[str, Any]] = None,
|
||||
model_compression_config: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
@@ -53,7 +53,7 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
@@ -327,9 +327,9 @@ class CompressedTensors24(CompressedTensorsScheme):
|
||||
)
|
||||
return sparsity_compressor.decompress_weight(weight_data)
|
||||
|
||||
split_weights: List[torch.Tensor] = []
|
||||
split_bitmask: List[torch.Tensor] = []
|
||||
split_shape: List[Tuple[int, int]] = []
|
||||
split_weights: list[torch.Tensor] = []
|
||||
split_bitmask: list[torch.Tensor] = []
|
||||
split_shape: list[tuple[int, int]] = []
|
||||
|
||||
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
split_weights = torch.split(compressed, layer.logical_widths)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@@ -58,7 +58,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
layer.meta = Parameter(layer.meta.data, requires_grad=False)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -26,7 +26,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
@@ -58,7 +58,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
@@ -90,7 +90,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional, Set
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
@@ -19,7 +19,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: Set[str] = set()
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool,
|
||||
input_symmetric: bool):
|
||||
@@ -33,7 +33,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional, Set
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
@@ -35,7 +35,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
_kernel_backends_being_used: Set[str] = set()
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self,
|
||||
strategy: str,
|
||||
@@ -70,7 +70,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
return 80
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||
input_size: int, output_partition_sizes: List[int],
|
||||
input_size: int, output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -126,7 +126,7 @@ def triton_scaled_mm(input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: Type[torch.dtype],
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
block_size_m: int = 32,
|
||||
block_size_n: int = 32,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Iterable, List, Mapping, Optional
|
||||
from typing import Optional
|
||||
|
||||
from compressed_tensors import CompressionFormat
|
||||
from torch.nn import Module
|
||||
@@ -20,7 +21,7 @@ def is_activation_quantization_format(format: str) -> bool:
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str] = tuple(),
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
@@ -84,7 +85,7 @@ def find_matched_target(
|
||||
layer_name: Optional[str],
|
||||
module: Module,
|
||||
targets: Iterable[str],
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> str:
|
||||
"""
|
||||
Helper function to look up which "target" in the compressed-tensors
|
||||
@@ -171,7 +172,7 @@ def _is_equal_or_regex_match(value: str,
|
||||
|
||||
def _match_fused_layer(
|
||||
layer_name: str, target_layers: Iterable[str],
|
||||
fused_mapping: Mapping[str, List[str]]) -> Optional[str]:
|
||||
fused_mapping: Mapping[str, list[str]]) -> Optional[str]:
|
||||
"""
|
||||
Match a fused layer name to its corresponding individual layer in
|
||||
target_layers. Returns first value in fused_mapping which matches targets
|
||||
@@ -201,7 +202,7 @@ def _match_fused_layer(
|
||||
]
|
||||
|
||||
# for each unfused component, find a match in targets
|
||||
unfused_matches: List[Optional[str]] = []
|
||||
unfused_matches: list[Optional[str]] = []
|
||||
for unfused in unfused_paths:
|
||||
for target in target_layers:
|
||||
if _is_equal_or_regex_match(unfused, target):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -46,7 +46,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
|
||||
return "deepspeedfp"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits=weight_bits, group_size=group_size)
|
||||
@@ -55,7 +55,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
|
||||
return DeepSpeedFPLinearMethod(self)
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -64,7 +64,7 @@ class DeepSpeedFPConfig(QuantizationConfig):
|
||||
return 60
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
@@ -91,7 +91,7 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
|
||||
def create_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,7 +25,7 @@ class ExpertsInt8Config(QuantizationConfig):
|
||||
return "experts_int8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -33,11 +33,11 @@ class ExpertsInt8Config(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -28,7 +28,7 @@ logger = init_logger(__name__)
|
||||
class FBGEMMFp8Config(QuantizationConfig):
|
||||
"""Config class for FBGEMM Fp8."""
|
||||
|
||||
def __init__(self, ignore_list: List[str], input_scale_ub: float):
|
||||
def __init__(self, ignore_list: list[str], input_scale_ub: float):
|
||||
super().__init__()
|
||||
self.ignore_list = ignore_list if ignore_list else []
|
||||
self.input_scale_ub = input_scale_ub
|
||||
@@ -43,7 +43,7 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
return "fbgemm_fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
@@ -51,11 +51,11 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
|
||||
ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
|
||||
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
|
||||
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
|
||||
@@ -82,7 +82,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib.util
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -57,8 +57,8 @@ class Fp8Config(QuantizationConfig):
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[List[str]] = None,
|
||||
weight_block_size: Optional[List[int]] = None,
|
||||
ignored_layers: Optional[list[str]] = None,
|
||||
weight_block_size: Optional[list[int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
@@ -90,7 +90,7 @@ class Fp8Config(QuantizationConfig):
|
||||
return "fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -98,11 +98,11 @@ class Fp8Config(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
@@ -191,7 +191,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@@ -35,7 +35,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
@@ -43,11 +43,11 @@ class GGUFConfig(QuantizationConfig):
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return [] # no extra configs.
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
@@ -215,7 +215,7 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
self.params_dtype = params_dtype
|
||||
@@ -406,7 +406,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: List[torch.Tensor]
|
||||
data_container: list[torch.Tensor]
|
||||
|
||||
def materialize_nested(self) -> Parameter:
|
||||
dtype = {data.dtype for data in self.data_container}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -34,11 +34,11 @@ class GPTQConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
@@ -84,7 +84,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -93,11 +93,11 @@ class GPTQConfig(QuantizationConfig):
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
@@ -135,7 +135,7 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -129,7 +129,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
return "gptq_bitblas"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -137,11 +137,11 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
@@ -185,7 +185,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
return self.TORCH_BITBLAS_STORAGE_DTYPE
|
||||
|
||||
@classmethod
|
||||
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
|
||||
def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
@@ -224,7 +224,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
|
||||
"""
|
||||
|
||||
kernel_type = BitBLASLinearKernel
|
||||
_kernel_backends_being_used: Set[str] = set()
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
@@ -236,7 +236,7 @@ class GPTQBitBLASLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,8 +45,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
|
||||
is_sym: bool, lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
full_config: Dict[str, Any]) -> None:
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
@@ -55,7 +55,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# Format is dict[str, dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
@@ -105,7 +105,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -113,11 +113,11 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
@@ -167,7 +167,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
GPTQMarlinLinearMethod)
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
@@ -199,7 +199,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
quant_config: The GPTQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
_kernel_backends_being_used: Set[str] = set()
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
@@ -212,7 +212,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -90,7 +90,7 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -99,11 +99,11 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
@@ -146,7 +146,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -32,7 +32,7 @@ class HQQMarlinConfig(QuantizationConfig):
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
skip_modules: Optional[List[str]] = None,
|
||||
skip_modules: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert group_size == 64, ("The only supported HQQ group size is "
|
||||
@@ -55,7 +55,7 @@ class HQQMarlinConfig(QuantizationConfig):
|
||||
return "hqq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -63,11 +63,11 @@ class HQQMarlinConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig":
|
||||
wq_params = (config["quant_config"]["weight_quant_params"])
|
||||
weight_bits = cls.get_from_keys(wq_params, ["nbits"])
|
||||
group_size = cls.get_from_keys(wq_params, ["group_size"])
|
||||
@@ -192,7 +192,7 @@ class HQQMarlinMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -32,7 +32,7 @@ class IPEXConfig(QuantizationConfig):
|
||||
method: str,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
modules_to_not_convert: Optional[List[str]] = None,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
desc_act: Optional[bool] = None,
|
||||
lm_head_quantized: Optional[bool] = None,
|
||||
) -> None:
|
||||
@@ -63,7 +63,7 @@ class IPEXConfig(QuantizationConfig):
|
||||
return "ipex"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.float16]
|
||||
|
||||
@classmethod
|
||||
@@ -71,14 +71,14 @@ class IPEXConfig(QuantizationConfig):
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return [
|
||||
"quant_config.json",
|
||||
"quantize_config.json",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "IPEXConfig":
|
||||
method = cls.get_from_keys(config, ["quant_method"]).lower()
|
||||
if method == "awq":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,8 +12,8 @@ from vllm.scalar_type import ScalarType
|
||||
|
||||
@dataclass
|
||||
class MPLinearLayerConfig:
|
||||
full_weight_shape: Tuple[int, int] # [in, out]
|
||||
partition_weight_shape: Tuple[int, int]
|
||||
full_weight_shape: tuple[int, int] # [in, out]
|
||||
partition_weight_shape: tuple[int, int]
|
||||
weight_type: ScalarType
|
||||
act_type: torch.dtype
|
||||
group_size: int
|
||||
@@ -31,7 +31,7 @@ class MPLinearKernel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self,
|
||||
@@ -75,7 +75,7 @@ class MPLinearKernel(ABC):
|
||||
torch.nn.Parameter(new_param.data, requires_grad=False))
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module) -> Tuple[
|
||||
self, layer: torch.nn.Module) -> tuple[
|
||||
torch.Tensor, # w_q
|
||||
torch.Tensor, # w_s
|
||||
Optional[torch.Tensor], # w_zp,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
|
||||
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
|
||||
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
@@ -29,7 +29,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
config: MPLinearLayerConfig,
|
||||
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
|
||||
compute_capability: Optional[int] = None) -> type[MPLinearKernel]:
|
||||
"""
|
||||
Choose an MPLinearKernel that can implement the given config for the given
|
||||
compute capability. Attempts to choose the best kernel in terms of
|
||||
@@ -46,7 +46,7 @@ def choose_mp_linear_kernel(
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
Type[MPLinearKernel]: Chosen kernel.
|
||||
type[MPLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
if compute_capability is None:
|
||||
if current_platform is None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,7 +22,7 @@ class AllSparkLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx:
|
||||
return False, "Act reordering currently not supported by AllSpark"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -21,10 +21,10 @@ logger = init_logger(__name__)
|
||||
|
||||
class BitBLASLinearKernel(MPLinearKernel):
|
||||
|
||||
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
|
||||
OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
|
||||
ENABLE_TUNING: bool = True
|
||||
MATMUL_LAYOUT: str = "nt"
|
||||
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
|
||||
BITBLAS_DTYPES: dict[torch.dtype, str] = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
@@ -103,7 +103,7 @@ class BitBLASLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
is_bitblas_installed = True
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,7 +25,7 @@ class ExllamaLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Exllama, "\
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,7 +25,7 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -24,7 +24,7 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
quant_types = query_marlin_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -24,7 +24,7 @@ class ScaledMMLinearKernel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str,
|
||||
@@ -50,7 +50,7 @@ class ScaledMMLinearKernel(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weight_params(
|
||||
self, layer: torch.nn.Module) -> Tuple[
|
||||
self, layer: torch.nn.Module) -> tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
Optional[torch.Tensor], # input_scale,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel)
|
||||
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
@@ -27,7 +27,7 @@ _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
|
||||
def choose_scaled_mm_linear_kernel(
|
||||
config: ScaledMMLinearLayerConfig,
|
||||
compute_capability: Optional[int] = None
|
||||
) -> Type[ScaledMMLinearKernel]:
|
||||
) -> type[ScaledMMLinearKernel]:
|
||||
"""
|
||||
Choose an ScaledMMLinearKernel that can implement the given config for the
|
||||
given compute capability. Attempts to choose the best kernel in terms of
|
||||
@@ -44,7 +44,7 @@ def choose_scaled_mm_linear_kernel(
|
||||
ValueError: If no kernel can implement the given config.
|
||||
|
||||
Returns:
|
||||
Type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
|
||||
if compute_capability is None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -22,7 +22,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if (not current_platform.is_cuda() and not current_platform.is_cpu()):
|
||||
return False, "CutlassScaledMM requires running on CUDA or CPU."
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +18,7 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
if current_platform.is_cpu():
|
||||
return (
|
||||
False,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import cond # noqa: F401
|
||||
@@ -25,7 +25,7 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
|
||||
@classmethod
|
||||
def can_implement(
|
||||
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
return False, "ScaledMMXLA requires running on TPU."
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -68,7 +68,7 @@ class MarlinConfig(QuantizationConfig):
|
||||
return "marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -77,11 +77,11 @@ class MarlinConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
@@ -128,7 +128,7 @@ class MarlinLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -53,7 +53,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
return "modelopt"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -61,11 +61,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
if quant_method not in QUANT_ALGOS:
|
||||
@@ -107,7 +107,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
@@ -177,7 +177,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
self,
|
||||
is_checkpoint_nvfp4_serialized: bool,
|
||||
kv_cache_quant_algo: str,
|
||||
exclude_modules: List[str],
|
||||
exclude_modules: list[str],
|
||||
group_size: int = 16,
|
||||
) -> None:
|
||||
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
|
||||
@@ -195,7 +195,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
return "nvfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
|
||||
|
||||
@classmethod
|
||||
@@ -203,11 +203,11 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
if quant_method not in QUANT_ALGOS:
|
||||
@@ -227,7 +227,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
||||
exclude_modules, group_size)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: List):
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: list):
|
||||
import re
|
||||
for pattern in exclude_modules:
|
||||
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
|
||||
@@ -296,7 +296,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -23,8 +23,8 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
|
||||
def __init__(self, linear_quant_method: str, weight_bits: int,
|
||||
group_size: int, has_zp: bool, lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[List[str]],
|
||||
full_config: Dict[str, Any]) -> None:
|
||||
modules_to_not_convert: Optional[list[str]],
|
||||
full_config: dict[str, Any]) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
@@ -69,7 +69,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
return "moe_wna16"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -77,11 +77,11 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
|
||||
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
@@ -109,7 +109,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
|
||||
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
return None
|
||||
|
||||
|
||||
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
|
||||
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
@@ -34,7 +34,7 @@ class NeuronQuantConfig(QuantizationConfig):
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "neuron_quant"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[str]:
|
||||
def get_supported_act_dtypes(self) -> list[str]:
|
||||
return SUPPORTED_QUANT_DTYPE_LIST
|
||||
|
||||
@classmethod
|
||||
@@ -43,11 +43,11 @@ class NeuronQuantConfig(QuantizationConfig):
|
||||
"This function should not be called with Neuron Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
|
||||
quantize_method = cls.get_from_keys(config, ["quantize_method"])
|
||||
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
|
||||
return cls(dequant_dtype=dequant_dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -32,7 +32,7 @@ class PTPCFp8Config(Fp8Config):
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[List[str]] = None,
|
||||
ignored_layers: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
if not current_platform.is_rocm():
|
||||
raise ValueError(
|
||||
@@ -55,7 +55,7 @@ class PTPCFp8Config(Fp8Config):
|
||||
return "ptpc_fp8"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
|
||||
def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
return cls(activation_scheme=activation_scheme,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -89,7 +89,7 @@ class QQQConfig(QuantizationConfig):
|
||||
return "qqq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
@@ -97,7 +97,7 @@ class QQQConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
return [
|
||||
"quant_config.json",
|
||||
@@ -105,7 +105,7 @@ class QQQConfig(QuantizationConfig):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
@@ -131,7 +131,7 @@ class QQQLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import fnmatch
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import torch
|
||||
|
||||
@@ -29,9 +29,9 @@ logger = init_logger(__name__)
|
||||
class QuarkConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self,
|
||||
quant_config: Dict[str, Any],
|
||||
kv_cache_group: Optional[List[str]] = None,
|
||||
kv_cache_config: Optional[Dict[str, Any]] = None,
|
||||
quant_config: dict[str, Any],
|
||||
kv_cache_group: Optional[list[str]] = None,
|
||||
kv_cache_config: Optional[dict[str, Any]] = None,
|
||||
pack_method: str = "reorder"):
|
||||
super().__init__()
|
||||
if kv_cache_group is None:
|
||||
@@ -44,7 +44,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -59,7 +59,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
|
||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(prefix,
|
||||
ignore=exclude_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
@@ -78,12 +78,12 @@ class QuarkConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
|
||||
export_config = config.get("export")
|
||||
if export_config is None:
|
||||
raise ValueError("The export key should be included in "
|
||||
"the configurations of Quark quantized model")
|
||||
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
|
||||
kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
|
||||
pack_method = cast(str, export_config.get("pack_method"))
|
||||
|
||||
# In the export model of quark, the quantization configuration
|
||||
@@ -95,7 +95,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
kv_cache_config = None
|
||||
else:
|
||||
kv_cache_set = set(kv_cache_group)
|
||||
layer_quant_config = cast(Dict[str, Any],
|
||||
layer_quant_config = cast(dict[str, Any],
|
||||
config.get("layer_quant_config"))
|
||||
layer_quant_names = list(layer_quant_config.keys())
|
||||
layer_quant_set = set(layer_quant_names)
|
||||
@@ -108,7 +108,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
"configuration.")
|
||||
|
||||
q_configs = [
|
||||
cast(Dict[str, Any], layer_quant_config.get(name))
|
||||
cast(dict[str, Any], layer_quant_config.get(name))
|
||||
for name in kv_cache_group
|
||||
]
|
||||
if not all(
|
||||
@@ -131,7 +131,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
|
||||
# In case q_proj output is also quantized, remove the configuration
|
||||
# to keep qkv consistency.
|
||||
q_proj_q_config = cast(Dict[str, Any],
|
||||
q_proj_q_config = cast(dict[str, Any],
|
||||
layer_quant_config.get("*q_proj"))
|
||||
if q_proj_q_config is not None:
|
||||
q_proj_q_config["output_tensors"] = None
|
||||
@@ -142,7 +142,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
pack_method=pack_method)
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
@@ -162,8 +162,8 @@ class QuarkConfig(QuantizationConfig):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
|
||||
input_quant: Optional[Dict[str, Any]]) -> bool:
|
||||
def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
@@ -187,8 +187,8 @@ class QuarkConfig(QuantizationConfig):
|
||||
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
|
||||
return is_per_tensor_activation
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
|
||||
input_quant: Optional[Dict[str, Any]]) -> bool:
|
||||
def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
@@ -209,8 +209,8 @@ class QuarkConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
|
||||
input_quant: Optional[Dict[str, Any]]) -> bool:
|
||||
def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
|
||||
input_quant: Optional[dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
logger.debug("Quark model is not in MX-FP4 format: "
|
||||
@@ -258,7 +258,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
return True
|
||||
|
||||
def _find_matched_config(self, layer_name: str,
|
||||
module: torch.nn.Module) -> Dict[str, Any]:
|
||||
module: torch.nn.Module) -> dict[str, Any]:
|
||||
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
if proj_name in self.packed_modules_mapping:
|
||||
@@ -283,29 +283,29 @@ class QuarkConfig(QuantizationConfig):
|
||||
return shard_configs[0]
|
||||
else:
|
||||
layer_quant_config = cast(
|
||||
Dict[str, Any], self.quant_config.get("layer_quant_config"))
|
||||
dict[str, Any], self.quant_config.get("layer_quant_config"))
|
||||
for name_pattern in layer_quant_config:
|
||||
if fnmatch.fnmatch(layer_name, name_pattern):
|
||||
return layer_quant_config[name_pattern]
|
||||
|
||||
layer_type = cast(str, type(module))
|
||||
layer_type_quant_config = cast(
|
||||
Dict[str, Any],
|
||||
dict[str, Any],
|
||||
self.quant_config.get("layer_type_quant_config"))
|
||||
if layer_type in layer_type_quant_config:
|
||||
return layer_type_quant_config[layer_type]
|
||||
|
||||
global_quant_config = cast(
|
||||
Dict[str, Any], self.quant_config.get("global_quant_config"))
|
||||
dict[str, Any], self.quant_config.get("global_quant_config"))
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
"and bias quantized are not supported")
|
||||
weight_config = cast(Dict[str, Any], config.get("weight"))
|
||||
input_config = cast(Dict[str, Any], config.get("input_tensors"))
|
||||
weight_config = cast(dict[str, Any], config.get("weight"))
|
||||
input_config = cast(dict[str, Any], config.get("input_tensors"))
|
||||
|
||||
if self._is_fp8_w8a8(weight_config, input_config):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
@@ -373,7 +373,7 @@ class QuarkLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
@@ -417,7 +417,7 @@ class QuarkKVCacheMethod(BaseKVCacheMethod):
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
|
||||
def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache configuration. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,7 +45,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
|
||||
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
|
||||
Any]):
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -18,8 +18,8 @@ __all__ = ["QuarkW4A4MXFP4"]
|
||||
|
||||
class QuarkW4A4MXFP4(QuarkScheme):
|
||||
|
||||
def __init__(self, weight_quant_spec: Dict[str, Any],
|
||||
input_quant_spec: Dict[str, Any]):
|
||||
def __init__(self, weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any]):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
@@ -74,7 +74,7 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@@ -88,7 +88,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional, Set
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
|
||||
_kernel_backends_being_used: Set[str] = set()
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
|
||||
input_symmetric: Optional[bool]):
|
||||
@@ -31,7 +31,7 @@ class QuarkW8A8Int8(QuarkScheme):
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Iterable, List, Mapping, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
@@ -21,7 +22,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str],
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
@@ -12,7 +12,7 @@ possible on ROCm), the model can be optionally augmented with KV cache
|
||||
scaling factors.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||
|
||||
@@ -23,7 +23,7 @@ class KVCacheQuantSchema(BaseModel):
|
||||
# layer indices to their per-tensor KV cache scaling factor.
|
||||
# TODO: Consider pulling this and its validation methods out into its
|
||||
# own schema class (tricky as its members are variable)
|
||||
scaling_factor: Dict[int, Dict[int, float]]
|
||||
scaling_factor: dict[int, dict[int, float]]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_is_fp8(self) -> "KVCacheQuantSchema":
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -24,7 +24,7 @@ class TorchAOConfig(QuantizationConfig):
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "torchao"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -32,11 +32,11 @@ class TorchAOConfig(QuantizationConfig):
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return ["config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
|
||||
"""Create the quant config from an hf model config"""
|
||||
try:
|
||||
from torchao.core.config import config_from_dict
|
||||
@@ -60,7 +60,7 @@ class TorchAOConfig(QuantizationConfig):
|
||||
return TorchAOLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
def get_scaled_act_names(self) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class TorchAOLinearMethod(LinearMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -31,7 +31,7 @@ class Int8TpuConfig(QuantizationConfig):
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "tpu_int8"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
@@ -40,11 +40,11 @@ class Int8TpuConfig(QuantizationConfig):
|
||||
"This function should not be called with TPU Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
|
||||
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(activation_scheme=activation_scheme)
|
||||
|
||||
@@ -62,7 +62,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: Module, input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
|
||||
@@ -77,7 +77,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def _quantize_weight(
|
||||
self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight_dtype = weight.dtype
|
||||
weight = weight.cpu().to(torch.float32)
|
||||
n_bit = 8
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -51,7 +51,7 @@ def _check_bitblas_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
@@ -133,7 +133,7 @@ def verify_bitblas_supports_shape(output_size_per_partition: int,
|
||||
def check_bitblas_supports_shape(output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int, group_size: int) \
|
||||
-> Tuple[bool, Optional[str]]:
|
||||
-> tuple[bool, Optional[str]]:
|
||||
try:
|
||||
verify_bitblas_supports_shape(output_size_per_partition,
|
||||
input_size_per_partition, input_size,
|
||||
@@ -166,7 +166,7 @@ def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
|
||||
|
||||
|
||||
def bitblas_sort_g_idx(
|
||||
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
|
||||
return g_idx[g_idx_sort_indices], g_idx_sort_indices
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -32,7 +32,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -95,7 +95,7 @@ def apply_w8a8_block_fp8_linear(
|
||||
def apply_w8a8_block_fp8_linear_fake(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -114,7 +114,7 @@ direct_register_custom_op(
|
||||
def input_to_float8(
|
||||
x: torch.Tensor,
|
||||
dtype: Optional[torch.dtype] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to float8 values "
|
||||
"with tensor-wise quantization."""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
@@ -129,7 +129,7 @@ def input_to_float8(
|
||||
def block_quant_to_tensor_quant(
|
||||
x_q_block: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function converts block-wise quantization to tensor-wise
|
||||
quantization. The inputs are block-wise quantization tensor `x_q_block`,
|
||||
block-wise quantization scale and the block size.
|
||||
@@ -247,7 +247,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
@@ -258,7 +258,7 @@ def per_token_group_quant_fp8(
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
@@ -412,7 +412,7 @@ def _w8a8_block_fp8_matmul(
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[Dict[int, Any]]:
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
The return value will be a dictionary that maps an irregular grid of
|
||||
@@ -452,7 +452,7 @@ def w8a8_block_fp8_matmul(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -52,7 +52,7 @@ def get_dynamic_override(
|
||||
layer_name: str,
|
||||
key: Optional[str] = None,
|
||||
default_value: Union[int, bool,
|
||||
None] = None) -> Union[Dict, int, bool, None]:
|
||||
None] = None) -> Union[dict, int, bool, None]:
|
||||
for pattern, pattern_dict in config.dynamic.items():
|
||||
# Negative match: matched modules are excluded from quantized init
|
||||
if pattern.startswith("-:"):
|
||||
|
||||
@@ -5,7 +5,7 @@ import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
def apply_w8a8_block_int8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -43,7 +43,7 @@ def apply_w8a8_block_int8_linear(
|
||||
|
||||
def input_to_int8(
|
||||
x: torch.Tensor,
|
||||
dtype: torch.dtype = torch.int8) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""This function quantizes input values to int8 values with
|
||||
tensor-wise quantization."""
|
||||
iinfo = torch.iinfo(dtype)
|
||||
@@ -58,7 +58,7 @@ def input_to_int8(
|
||||
def block_dequant(
|
||||
x_q_block: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""This function conducts block-wise dequantization.
|
||||
The inputs are block-wise quantization tensor `x_q_block`,
|
||||
@@ -211,7 +211,7 @@ def per_token_group_quant_int8(
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = torch.int8,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
|
||||
It converts the tensor values into signed int8 values and returns the
|
||||
@@ -225,7 +225,7 @@ def per_token_group_quant_int8(
|
||||
is supported for now.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
assert (x.shape[-1] % group_size == 0
|
||||
@@ -358,7 +358,7 @@ def _w8a8_block_int8_matmul(
|
||||
|
||||
@functools.lru_cache
|
||||
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
|
||||
block_k: int) -> Optional[Dict[int, Any]]:
|
||||
block_k: int) -> Optional[dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the w8a8 block fp8 kernel.
|
||||
|
||||
@@ -399,7 +399,7 @@ def w8a8_block_int8_matmul(
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,19 +10,19 @@ MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
|
||||
|
||||
|
||||
def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
|
||||
def query_machete_supported_quant_types(zero_points: bool) -> list[ScalarType]:
|
||||
if zero_points:
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
|
||||
def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def check_machete_supports_shape(in_features: int, out_featrues: int) \
|
||||
-> Tuple[bool, Optional[str]]:
|
||||
-> tuple[bool, Optional[str]]:
|
||||
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
|
||||
return False, "Input features size must be divisible by "\
|
||||
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -70,7 +70,7 @@ def _check_marlin_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
@@ -143,7 +143,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
|
||||
def check_marlin_supports_shape(output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int, group_size: int) \
|
||||
-> Tuple[bool, Optional[str]]:
|
||||
-> tuple[bool, Optional[str]]:
|
||||
try:
|
||||
verify_marlin_supports_shape(output_size_per_partition,
|
||||
input_size_per_partition, input_size,
|
||||
@@ -231,16 +231,16 @@ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
|
||||
|
||||
|
||||
def marlin_sort_g_idx(
|
||||
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
|
||||
return g_idx[g_idx_sort_indices], g_idx_sort_indices
|
||||
|
||||
|
||||
def get_scale_perms():
|
||||
scale_perm: List[int] = []
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single: List[int] = []
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend(
|
||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Utility functions used for tests and benchmarks"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -64,9 +64,9 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
|
||||
|
||||
|
||||
def get_weight_perm(num_bits: int):
|
||||
perm_list: List[int] = []
|
||||
perm_list: list[int] = []
|
||||
for i in range(32):
|
||||
perm1: List[int] = []
|
||||
perm1: list[int] = []
|
||||
col = i // 4
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
"""Utility functions used for tests and benchmarks"""
|
||||
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -373,19 +372,19 @@ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
|
||||
|
||||
|
||||
def get_scale_perms_24():
|
||||
scale_perm: List[int] = []
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
|
||||
scale_perm_single: List[int] = []
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
|
||||
return scale_perm, scale_perm_single
|
||||
|
||||
|
||||
def get_weight_perm_24(num_bits: int):
|
||||
perm_list: List[int] = []
|
||||
perm_list: list[int] = []
|
||||
for i in range(32):
|
||||
perm1: List[int] = []
|
||||
perm1: list[int] = []
|
||||
col = i // 4
|
||||
col_o = col // 2
|
||||
for block in [0, 1]:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
@@ -34,10 +32,10 @@ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
|
||||
|
||||
|
||||
def get_qqq_scale_perms():
|
||||
scale_perm: List[int] = []
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single: List[int] = []
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend(
|
||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
||||
@@ -46,9 +44,9 @@ def get_qqq_scale_perms():
|
||||
|
||||
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
|
||||
def get_qqq_weight_perm(num_bits: int, quant_type: str):
|
||||
perm_list: List[int] = []
|
||||
perm_list: list[int] = []
|
||||
for i in range(32):
|
||||
perm1: List[int] = []
|
||||
perm1: list[int] = []
|
||||
col = i // 4
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -9,7 +8,7 @@ OCP_MX_BLOCK_SIZE = 32
|
||||
def per_token_group_quant_mxfp4(x: torch.Tensor,
|
||||
block_k: int,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
try:
|
||||
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
||||
fake_quantize_fp4_fp6_per_group_with_scale)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import List, Mapping, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -15,7 +16,7 @@ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
# Normalize the group_shape to the full extent for any dims that are -1
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
|
||||
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
|
||||
int]):
|
||||
# -1 means full extent
|
||||
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
|
||||
@@ -56,9 +57,9 @@ def group_broadcast(t, shape):
|
||||
# (i.e. per-token-per-group)
|
||||
def scaled_quantize(
|
||||
x: torch.Tensor,
|
||||
group_shape: Tuple[int, int],
|
||||
group_shape: tuple[int, int],
|
||||
quant_dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
assert quant_dtype.is_floating_point, \
|
||||
"currently `scaled_quantize` only supports floating point dtypes " \
|
||||
@@ -97,9 +98,9 @@ def scaled_quantize(
|
||||
def scaled_dequantize(
|
||||
x_q: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
group_shape: Optional[Tuple[int, int]] = None,
|
||||
group_shape: Optional[tuple[int, int]] = None,
|
||||
out_dtype: torch.dtype = torch.float32,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if group_shape is not None:
|
||||
group_shape = _normalize_quant_group_shape(x_q, group_shape)
|
||||
|
||||
@@ -173,8 +174,8 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
|
||||
def is_layer_skipped(
|
||||
prefix: str,
|
||||
ignored_layers: List[str],
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
||||
ignored_layers: list[str],
|
||||
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
# prefix: model.layers.0.self_attn.q_proj
|
||||
# proj_name: q_proj
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -81,7 +81,7 @@ def all_close_1d(x: torch.Tensor) -> bool:
|
||||
|
||||
def convert_to_channelwise(
|
||||
weight_scale: torch.Tensor,
|
||||
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Create channelwise buffer
|
||||
weight_scale_channel = torch.empty((sum(logical_widths), 1),
|
||||
dtype=torch.float32,
|
||||
@@ -99,7 +99,7 @@ def convert_to_channelwise(
|
||||
|
||||
def requantize_with_max_scale(
|
||||
weight: torch.Tensor, weight_scale: torch.Tensor,
|
||||
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Max scale to be used for requanitzation.
|
||||
max_w_scale = weight_scale.max()
|
||||
|
||||
@@ -136,7 +136,7 @@ def maybe_create_device_identity():
|
||||
def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
|
||||
out_dtype: torch.dtype, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
output_shape: List, **kwargs) -> torch.Tensor:
|
||||
output_shape: list, **kwargs) -> torch.Tensor:
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(qinput,
|
||||
@@ -154,7 +154,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: List) -> torch.Tensor:
|
||||
output_shape: list) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi250_mi300
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
@@ -177,7 +177,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: List) -> torch.Tensor:
|
||||
output_shape: list) -> torch.Tensor:
|
||||
output = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
out_dtype=out_dtype,
|
||||
@@ -198,7 +198,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: List) -> torch.Tensor:
|
||||
output_shape: list) -> torch.Tensor:
|
||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||
# when using it.
|
||||
# For now it has only been validated on ROCm platform.
|
||||
@@ -228,7 +228,7 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: List,
|
||||
output_shape: list,
|
||||
**kwargs) -> torch.Tensor:
|
||||
# Use unfused DQ due to limitations with scaled_mm
|
||||
|
||||
@@ -384,7 +384,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert weight.dtype == torch.float8_e4m3fn
|
||||
# The bits pattern 10000000(-128) represents zero in e4m3fn
|
||||
# but NaN in e4m3fnuz. So here we set it to 0.
|
||||
|
||||
Reference in New Issue
Block a user