[Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776)

This commit is contained in:
chenqianfzh
2024-06-01 13:51:10 -07:00
committed by GitHub
parent 37464a0f74
commit b9c0605a8e
11 changed files with 752 additions and 8 deletions

View File

@@ -92,6 +92,8 @@ class EngineArgs:
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
qlora_adapter_name_or_path: Optional[str] = None
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
@@ -159,7 +161,8 @@ class EngineArgs:
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
@@ -173,7 +176,9 @@ class EngineArgs:
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.\n')
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
'--dtype',
type=str,
@@ -543,7 +548,10 @@ class EngineArgs:
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one.")
parser.add_argument('--qlora-adapter-name-or-path',
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
return parser
@classmethod
@@ -555,6 +563,23 @@ class EngineArgs:
return engine_args
def create_engine_config(self, ) -> EngineConfig:
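# bitsandbytes quantization needs its own model loader, so the quantization
# method and the load format must both be "bitsandbytes"; supplying a QLoRA
# adapter requires both as well. Reject any inconsistent combination early.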
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
if (self.load_format == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
@@ -622,6 +647,13 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,