Implement AWQ quantization support for LLaMA (#1032)

Co-authored-by: Robert Irvine <robert@seamlessml.com>
Co-authored-by: root <rirv938@gmail.com>
Co-authored-by: Casper <casperbh.96@gmail.com>
Co-authored-by: julian-q <julianhquevedo@gmail.com>
Author: Woosuk Kwon
Date: 2023-09-16 00:03:37 -07:00
Committed by: GitHub
Parent: b9fe4616f9
Commit: e3e79e9e8a

19 changed files with 1178 additions and 208 deletions

@@ -43,6 +43,8 @@ class ModelConfig:
             version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
+        quantization: Quantization method that was used to quantize the model
+            weights. If None, we assume the model weights are not quantized.
     """
 
     def __init__(
@@ -57,6 +59,7 @@ class ModelConfig:
         seed: int,
         revision: Optional[str],
         max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -66,11 +69,13 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.quantization = quantization
 
         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self._verify_quantization()
         self.max_model_len = None
         if max_model_len is not None:
             derived_max_model_len = self.get_max_model_len()
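Downstream of this constructor wiring, the PR threads the same flag from the engine arguments into ModelConfig, so it can be set from the LLM entrypoint. A minimal usage sketch follows; the AWQ checkpoint name is a placeholder and not part of this diff.

# Hedged usage sketch: load an AWQ-quantized LLaMA checkpoint by passing
# quantization="awq". The model path below is illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)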
@@ -100,6 +105,17 @@ class ModelConfig:
                 "either 'auto' or 'slow'.")
         self.tokenizer_mode = tokenizer_mode
 
+    def _verify_quantization(self) -> None:
+        supported_quantization = ["awq"]
+        if self.quantization is None:
+            return
+        quantization = self.quantization.lower()
+        if quantization not in supported_quantization:
+            raise ValueError(
+                f"Unknown quantization: {self.quantization}. Must be one of "
+                f"{supported_quantization}.")
+        self.quantization = quantization
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
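For reference, _verify_quantization lower-cases the flag before checking it, so matching is case-insensitive, and only "awq" is accepted at this commit. A standalone sketch of that behavior, using a stripped-down stand-in class rather than the real ModelConfig:

from typing import Optional

class _StubConfig:
    """Hypothetical stand-in reproducing only the validation shown above."""

    def __init__(self, quantization: Optional[str] = None) -> None:
        self.quantization = quantization
        self._verify_quantization()

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq"]
        if self.quantization is None:
            return  # unquantized weights: nothing to validate
        quantization = self.quantization.lower()
        if quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization: {self.quantization}. Must be one of "
                f"{supported_quantization}.")
        self.quantization = quantization

print(_StubConfig("AWQ").quantization)   # -> "awq" (case-insensitive match)
print(_StubConfig().quantization)        # -> None
try:
    _StubConfig("gptq")                  # not in ["awq"] at this commit
except ValueError as err:
    print(err)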