Improve configs - ModelConfig (#17130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
|
||||
f"out_group_size={self.out_group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "aqlm"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
@@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
|
||||
f"zero_point={self.zero_point}, "
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "awq"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
f"modules_to_not_convert={self.modules_to_not_convert})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "awq_marlin"
|
||||
|
||||
@classmethod
|
||||
@@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
|
||||
or user_quant == "awq_marlin")
|
||||
|
||||
@@ -2,11 +2,16 @@
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
else:
|
||||
QuantizationMethods = str
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
@@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
|
||||
self.packed_modules_mapping: Dict[str, List[str]] = dict()
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
"""
|
||||
Detects if this quantization method can support a given checkpoint
|
||||
format by overriding the user specified quantization method --
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
@@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
|
||||
f"quant_method={self.quant_method})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "bitblas"
|
||||
|
||||
@classmethod
|
||||
@@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_bitblas_format: bool
|
||||
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "bitsandbytes"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
||||
@@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "compressed-tensors"
|
||||
|
||||
def get_quant_method(
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
|
||||
f"group_size={self.group_size}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "DeepSpeedFP"
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "deepspeedfp"
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
|
||||
|
||||
@@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "experts_int8"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
@@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
self.fp8_linear = Fp8LinearOp()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fbgemm_fp8"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
@@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
|
||||
self.weight_block_size = weight_block_size
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "fp8"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
def __repr__(self) -> str:
|
||||
return ("GGUFConfig()")
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "gguf"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
|
||||
@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
@@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
f"dynamic={self.dynamic}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
@@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
f"quant_method={self.quant_method})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_bitblas"
|
||||
|
||||
@classmethod
|
||||
@@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
@@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
f"dynamic={self.dynamic}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
@@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
lm_head_quantized, dynamic, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
|
||||
|
||||
@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
@@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
self.quant_type, self.group_size)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
@@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
is_marlin_24_format = (
|
||||
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "hqq"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
|
||||
f"group_size={self.group_size})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ipex"
|
||||
|
||||
@classmethod
|
||||
@@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
|
||||
lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
if not current_platform.is_cpu() and not current_platform.is_xpu():
|
||||
return None
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
@@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "marlin"
|
||||
|
||||
@classmethod
|
||||
@@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
|
||||
return cls(group_size, lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
@@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
" the format is experimental and could change.")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "modelopt"
|
||||
|
||||
@classmethod
|
||||
@@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
self.exclude_modules = exclude_modules
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "modelopt_nvfp4"
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "nvfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "moe_wna16"
|
||||
|
||||
@classmethod
|
||||
@@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
lm_head_quantized, modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
||||
if can_convert and user_quant == "moe_wna16":
|
||||
return cls.get_name()
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
@@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
|
||||
self.dequant_dtype = dequant_dtype
|
||||
self.quantize_method = quantize_method
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "neuron_quant"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[str]:
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||
@@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
|
||||
ignored_layers=ignored_layers)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "ptpc_fp8"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
@@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
|
||||
self.weight_bits, self.group_size)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "qqq"
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
@@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "quark"
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
|
||||
def __repr__(self) -> str:
|
||||
return f"TorchAOConfig({self.torchao_config})"
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "torchao"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
|
||||
@@ -7,6 +7,7 @@ from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
@@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
|
||||
f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
def get_name(self) -> str:
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "tpu_int8"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
|
||||
Reference in New Issue
Block a user