[4/N] make quant config first-class citizen (#9978)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -23,9 +23,13 @@ if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.model_loader.loader import BaseModelLoader
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
else:
|
||||
QuantizationConfig = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1966,6 +1970,35 @@ class VllmConfig:
|
||||
decoding_config: Optional[DecodingConfig] = None
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
|
||||
@staticmethod
def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Resolve and validate the quantization config for this model.

    Returns ``None`` when the model is not quantized. Otherwise the
    quant config is loaded and checked against the current device
    capability and the model dtype; a ``ValueError`` is raised when
    either check fails.
    """
    # Unquantized model: nothing to resolve.
    if model_config.quantization is None:
        return None

    # Imported lazily; weight_utils pulls in heavy dependencies.
    from vllm.model_executor.model_loader.weight_utils import (
        get_quant_config)
    quant_config = get_quant_config(model_config, load_config)

    # Capability may be unavailable (e.g. non-GPU platforms); the
    # check is skipped in that case — TODO confirm that is intended.
    capability_tuple = current_platform.get_device_capability()
    if capability_tuple is not None:
        capability = capability_tuple.to_int()
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} "
                "is not supported for the current GPU. Minimum "
                f"capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")

    # The model dtype must be one the quant method can act on.
    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")
    return quant_config
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
@@ -1983,3 +2016,8 @@ class VllmConfig:
|
||||
if self.prompt_adapter_config:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
|
||||
if self.quant_config is None and \
|
||||
self.model_config is not None and self.load_config is not None:
|
||||
self.quant_config = VllmConfig._get_quantization_config(
|
||||
self.model_config, self.load_config)
|
||||
|
||||
Reference in New Issue
Block a user