[[Misc]Upgrade bitsandbytes to the latest version 0.44.0 (#8768)
This commit is contained in:
@@ -222,6 +222,7 @@ class ModelConfig:
|
||||
self._verify_embedding_mode()
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
self._verify_bnb_config()
|
||||
|
||||
def _init_multimodal_config(
|
||||
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
|
||||
@@ -337,6 +338,28 @@ class ModelConfig:
|
||||
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
||||
self.max_model_len)
|
||||
|
||||
def _verify_bnb_config(self) -> None:
|
||||
"""
|
||||
The current version of bitsandbytes (0.44.0) with 8-bit models does not
|
||||
yet support CUDA graph.
|
||||
"""
|
||||
is_bitsandbytes = self.quantization == "bitsandbytes"
|
||||
has_quantization_config = (getattr(self.hf_config,
|
||||
"quantization_config", None)
|
||||
is not None)
|
||||
is_8bit = (self.hf_config.quantization_config.get(
|
||||
"load_in_8bit", False) if has_quantization_config else False)
|
||||
if all([
|
||||
is_bitsandbytes,
|
||||
has_quantization_config,
|
||||
is_8bit,
|
||||
not self.enforce_eager,
|
||||
]):
|
||||
logger.warning(
|
||||
"CUDA graph is not supported on BitAndBytes 8bit yet, "
|
||||
"fallback to the eager mode.")
|
||||
self.enforce_eager = True
|
||||
|
||||
def verify_async_output_proc(self, parallel_config, speculative_config,
|
||||
device_config) -> None:
|
||||
if not self.use_async_output_proc:
|
||||
@@ -401,13 +424,6 @@ class ModelConfig:
|
||||
"Pipeline parallelism is only supported for the following "
|
||||
f" architectures: {_PP_SUPPORTED_MODELS}.")
|
||||
|
||||
# Remove the constraint after the bitsandbytes issue is fixed:
|
||||
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
|
||||
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
|
||||
logger.warning("CUDA graph is not supported on BitAndBytes yet, "
|
||||
"fallback to the eager mode.")
|
||||
self.enforce_eager = True
|
||||
|
||||
if pipeline_parallel_size > 1 and self.use_async_output_proc:
|
||||
logger.warning("Async output processor is not supported with "
|
||||
"pipeline parallelism currently. Disabling it.")
|
||||
|
||||
Reference in New Issue
Block a user