[Quantization] Modify the logic of BNB double quantization (#19742)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -492,8 +492,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             raise ValueError("Following weights were not initialized from "
                              f"checkpoint: {weights_not_loaded}")
 
-        torch.cuda.empty_cache()
-
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
@@ -545,6 +543,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
+                # Dequantize double quantized values during weight loading.
+                dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
 
                 pack_ratio = getattr(param, "pack_factor", -1)
@@ -565,6 +565,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
+        torch.cuda.empty_cache()
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+def dequantize_dq(quant_states: dict) -> None:
+    """
+    When BNB employs Double Quantization, we perform the dequantization of
+    these constants during weight loading rather than at inference time,
+    thereby avoiding this computational overhead during inference. This comes
+    at the cost of increased memory usage.
+    """
+    from bitsandbytes.functional import dequantize_blockwise
+    for _, quant_state in quant_states.items():
+        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
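For context, here is a minimal standalone sketch (not part of the commit) of what dequantize_dq undoes. With double quantization, bitsandbytes blockwise-quantizes the per-block absmax constants themselves, so each QuantState carries a second-level state (state2) plus an offset. The sketch assumes a CUDA device and bitsandbytes >= 0.45; flatten_state is a hypothetical stand-in that mirrors the loop body above for a single QuantState.

import torch
import bitsandbytes.functional as F


def flatten_state(qs) -> None:
    # Same steps as the commit's dequantize_dq, for one QuantState:
    # recover the fp32 absmax and drop the second-level state.
    if qs.nested:
        absmax = F.dequantize_blockwise(qs.absmax, qs.state2)
        absmax += qs.offset
        if absmax.dtype != torch.float32:
            absmax = absmax.float()
        qs.absmax = absmax
        qs.nested = False
        qs.offset = None
        qs.state2 = None


weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# compress_statistics=True enables double quantization: qs.absmax is
# stored quantized, and qs.state2/qs.offset describe how to recover it.
packed, qs = F.quantize_4bit(weight, compress_statistics=True, quant_type="nf4")
assert qs.nested

reference = F.dequantize_4bit(packed, qs)

# Flatten once at load time; afterwards absmax is plain fp32 and
# dequantize_4bit no longer pays the second blockwise dequantization.
flatten_state(qs)
assert not qs.nested and qs.absmax.dtype == torch.float32
assert torch.allclose(F.dequantize_4bit(packed, qs), reference)

This is the trade-off the docstring describes: the flattened fp32 absmax occupies more memory than its quantized form, but every subsequent dequantization of the weight skips the nested step.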