[Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776)
This commit is contained in:
@@ -92,6 +92,8 @@ class EngineArgs:
|
||||
ngram_prompt_lookup_max: Optional[int] = None
|
||||
ngram_prompt_lookup_min: Optional[int] = None
|
||||
|
||||
qlora_adapter_name_or_path: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
self.tokenizer = self.model
|
||||
@@ -159,7 +161,8 @@ class EngineArgs:
|
||||
type=str,
|
||||
default=EngineArgs.load_format,
|
||||
choices=[
|
||||
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
|
||||
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
|
||||
'bitsandbytes'
|
||||
],
|
||||
help='The format of the model weights to load.\n\n'
|
||||
'* "auto" will try to load the weights in the safetensors format '
|
||||
@@ -173,7 +176,9 @@ class EngineArgs:
|
||||
'which is mainly for profiling.\n'
|
||||
'* "tensorizer" will load the weights using tensorizer from '
|
||||
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
|
||||
'section for more information.\n')
|
||||
'section for more information.\n'
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
'quantization.\n')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
@@ -543,7 +548,10 @@ class EngineArgs:
|
||||
"will also be used in `model_name` tag content of "
|
||||
"prometheus metrics, if multiple names provided, metrics"
|
||||
"tag will take the first one.")
|
||||
|
||||
parser.add_argument('--qlora-adapter-name-or-path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Name or path of the QLoRA adapter.')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@@ -555,6 +563,23 @@ class EngineArgs:
|
||||
return engine_args
|
||||
|
||||
def create_engine_config(self, ) -> EngineConfig:
|
||||
|
||||
# bitsandbytes quantization needs a specific model loader
|
||||
# so we make sure the quant method and the load format are consistent
|
||||
if (self.quantization == "bitsandbytes" or
|
||||
self.qlora_adapter_name_or_path is not None) and \
|
||||
self.load_format != "bitsandbytes":
|
||||
raise ValueError(
|
||||
"BitsAndBytes quantization and QLoRA adapter only support "
|
||||
f"'bitsandbytes' load format, but got {self.load_format}")
|
||||
|
||||
if (self.load_format == "bitsandbytes" or
|
||||
self.qlora_adapter_name_or_path is not None) and \
|
||||
self.quantization != "bitsandbytes":
|
||||
raise ValueError(
|
||||
"BitsAndBytes load format and QLoRA adapter only support "
|
||||
f"'bitsandbytes' quantization, but got {self.quantization}")
|
||||
|
||||
device_config = DeviceConfig(self.device)
|
||||
model_config = ModelConfig(
|
||||
self.model, self.tokenizer, self.tokenizer_mode,
|
||||
@@ -622,6 +647,13 @@ class EngineArgs:
|
||||
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
|
||||
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
|
||||
|
||||
if self.qlora_adapter_name_or_path is not None and \
|
||||
self.qlora_adapter_name_or_path != "":
|
||||
if self.model_loader_extra_config is None:
|
||||
self.model_loader_extra_config = {}
|
||||
self.model_loader_extra_config[
|
||||
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
|
||||
|
||||
load_config = LoadConfig(
|
||||
load_format=self.load_format,
|
||||
download_dir=self.download_dir,
|
||||
|
||||
Reference in New Issue
Block a user