Implement AWQ quantization support for LLaMA (#1032)
Co-authored-by: Robert Irvine <robert@seamlessml.com> Co-authored-by: root <rirv938@gmail.com> Co-authored-by: Casper <casperbh.96@gmail.com> Co-authored-by: julian-q <julianhquevedo@gmail.com>
This commit is contained in:
@@ -43,6 +43,8 @@ class ModelConfig:
|
||||
version.
|
||||
max_model_len: Maximum length of a sequence (including prompt and
|
||||
output). If None, will be derived from the model.
|
||||
quantization: Quantization method that was used to quantize the model
|
||||
weights. If None, we assume the model weights are not quantized.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -57,6 +59,7 @@ class ModelConfig:
|
||||
seed: int,
|
||||
revision: Optional[str],
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
@@ -66,11 +69,13 @@ class ModelConfig:
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
|
||||
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_quantization()
|
||||
self.max_model_len = None
|
||||
if max_model_len is not None:
|
||||
derived_max_model_len = self.get_max_model_len()
|
||||
@@ -100,6 +105,17 @@ class ModelConfig:
|
||||
"either 'auto' or 'slow'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq"]
|
||||
if self.quantization is None:
|
||||
return
|
||||
quantization = self.quantization.lower()
|
||||
if quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||
f"{supported_quantization}.")
|
||||
self.quantization = quantization
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
|
||||
Reference in New Issue
Block a user