Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

Author:    Harry Mellor
Date:      2025-04-30 03:38:22 +01:00
Committed: GitHub
Parent:    2c4f59afc3
Commit:    13698db634

36 changed files with 490 additions and 648 deletions

View File

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
                 f"out_group_size={self.out_group_size})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "aqlm"
 
     @classmethod

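Every file below receives the same two-part change: an import of `QuantizationMethods` from `vllm.model_executor.layers.quantization`, and a `get_name` signature that returns `QuantizationMethods` instead of a bare `str`. Judging by the import site, `QuantizationMethods` is presumably a `Literal` alias over the registered method names, roughly like this abbreviated sketch (not the actual definition):

```python
# Hypothetical, abbreviated sketch of the alias; the real one lives in
# vllm/model_executor/layers/quantization/__init__.py and names every method.
from typing import Literal, get_args

QuantizationMethods = Literal["aqlm", "awq", "awq_marlin", "fp8", "gptq",
                              "gptq_marlin", "marlin", "nvfp4"]

# The runtime list of valid names can be derived from the alias, so the
# type and the registry cannot drift apart.
QUANTIZATION_METHODS: list = list(get_args(QuantizationMethods))
```

With this in place, a type checker rejects any `get_name` implementation that returns a string outside the registered set.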
View File

@@ -7,6 +7,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
                 f"zero_point={self.zero_point}, "
                 f"modules_to_not_convert={self.modules_to_not_convert})")
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "awq"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
                 f"modules_to_not_convert={self.modules_to_not_convert})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "awq_marlin"
 
     @classmethod
@@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
                    modules_to_not_convert, config)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                                or user_quant == "awq_marlin")

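`override_quantization_method` also changes its return annotation, from `Optional[str]` to `Optional[QuantizationMethods]`. Per its docstring (visible in base_config.py below), each config class uses it to claim a checkpoint format by overriding the user-specified method. A minimal sketch of how a caller might consult the hook, with illustrative names only (this is not vLLM's actual resolution code):

```python
from typing import Optional

def resolve_quant_method(hf_quant_cfg: dict,
                         user_quant: Optional[str],
                         config_classes: list) -> Optional[str]:
    """Ask each config class whether it wants to claim this checkpoint."""
    for config_cls in config_classes:
        # e.g. AWQMarlinConfig upgrades a plain "awq" checkpoint to
        # "awq_marlin" when the weights are Marlin-compatible.
        override = config_cls.override_quantization_method(
            hf_quant_cfg, user_quant)
        if override is not None:
            return override
    return user_quant
```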
View File

@@ -2,11 +2,16 @@
 
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
 from torch import nn
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+else:
+    QuantizationMethods = str
+
 
 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""
@@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
         self.packed_modules_mapping: Dict[str, List[str]] = dict()
 
     @abstractmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         raise NotImplementedError
 
@@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
         raise NotImplementedError
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         """
         Detects if this quantization method can support a given checkpoint
         format by overriding the user specified quantization method --

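The `TYPE_CHECKING` guard above is the standard way to annotate against a module that imports you back: `base_config.py` cannot import from `vllm.model_executor.layers.quantization` at runtime without creating a circular import, so the alias is imported only for the type checker and degrades to plain `str` at runtime. A self-contained sketch of the pattern (the module name is illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, so no runtime import cycle.
    from my_registry import MethodNames  # hypothetical module
else:
    # At runtime the annotation is just str, so annotations still evaluate.
    MethodNames = str

def get_name() -> MethodNames:
    return "example"
```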
View File

@@ -5,6 +5,7 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
                 f"quant_method={self.quant_method})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "bitblas"
 
     @classmethod
@@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
                    lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_bitblas_format: bool
         is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"

View File

@@ -7,6 +7,7 @@ import torch
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.utils import direct_register_custom_op
@@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
                 f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
 
     @classmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "bitsandbytes"
 
     @classmethod

View File

@@ -16,6 +16,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
@@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 70
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"
 
     def get_quant_method(

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
                 f"group_size={self.group_size}")
 
     @classmethod
-    def get_name(cls) -> str:
-        return "DeepSpeedFP"
+    def get_name(cls) -> QuantizationMethods:
+        return "deepspeedfp"
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":

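Note the behavioral change buried in this hunk: the reported name switches from "DeepSpeedFP" to "deepspeedfp" (and, further down, ModelOptNvFp4Config switches from "modelopt_nvfp4" to "nvfp4"). This is exactly the drift a `Literal` return type surfaces, since a checker flags any string outside the declared set:

```python
from typing import Literal

# Abbreviated sketch; assumes "deepspeedfp" is the registered spelling.
QuantizationMethods = Literal["deepspeedfp", "fp8"]

def get_name() -> QuantizationMethods:
    return "DeepSpeedFP"  # mypy: incompatible return value type
```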
View File

@@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
         super().__init__()
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "experts_int8"
 
     @classmethod

View File

@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
         self.fp8_linear = Fp8LinearOp()
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fbgemm_fp8"
 
     @classmethod

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
         self.weight_block_size = weight_block_size
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fp8"
 
     @classmethod

View File

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return ("GGUFConfig()")
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "gguf"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
@@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
                 f"dynamic={self.dynamic}")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq"
 
     @classmethod

View File

@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
                 f"quant_method={self.quant_method})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_bitblas"
 
     @classmethod
@@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
                    lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
 
         is_valid_user_quant = (user_quant is None or user_quant == "bitblas"

View File

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 f"dynamic={self.dynamic}")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_marlin"
 
     @classmethod
@@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
                    lm_head_quantized, dynamic, config)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
 
         is_valid_user_quant = (user_quant is None or user_quant == "marlin"

View File

@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
            self.quant_type, self.group_size)
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_marlin_24"
 
     @classmethod
@@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
         return cls(weight_bits, group_size)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         is_marlin_24_format = (
             hf_quant_cfg.get("checkpoint_format") == "marlin_24")
 

View File

@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
                 f"group_size={self.group_size})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "hqq"
 
     @classmethod
@classmethod

View File

@@ -6,6 +6,7 @@ import torch
 
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
                 f"group_size={self.group_size})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "ipex"
 
     @classmethod
@@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
                    lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         if not current_platform.is_cpu() and not current_platform.is_xpu():
             return None
 

View File

@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
                 f"lm_head_quantized={self.lm_head_quantized})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "marlin"
 
     @classmethod
@@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
         return cls(group_size, lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_marlin_format: bool
         is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"

View File

@@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
                        " the format is experimental and could change.")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "modelopt"
 
     @classmethod
@@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
         self.exclude_modules = exclude_modules
 
     @classmethod
-    def get_name(cls) -> str:
-        return "modelopt_nvfp4"
+    def get_name(cls) -> QuantizationMethods:
+        return "nvfp4"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
         self.modules_to_not_convert = modules_to_not_convert
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "moe_wna16"
 
     @classmethod
@@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
                    lm_head_quantized, modules_to_not_convert, config)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
         if can_convert and user_quant == "moe_wna16":
             return cls.get_name()

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
 
 from torch.nn import Module
 
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 
@@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
         self.dequant_dtype = dequant_dtype
         self.quantize_method = quantize_method
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "neuron_quant"
 
     def get_supported_act_dtypes(self) -> List[str]:

View File

@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
@@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
                          ignored_layers=ignored_layers)
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "ptpc_fp8"
 
     @classmethod

View File

@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
            self.weight_bits, self.group_size)
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "qqq"
 
     @classmethod

View File

@@ -8,6 +8,7 @@ import torch
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 70
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "quark"
 
     def get_quant_method(self, layer: torch.nn.Module,

View File

@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "torchao"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@@ -7,6 +7,7 @@ from torch.nn import Module
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import ModelWeightParameter
@@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
                 f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "tpu_int8"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]: