[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
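For context, a minimal usage sketch of the path this commit enables: loading a Mixture-of-Experts checkpoint with in-flight BitsAndBytes (NF4) quantization. The model id is a placeholder and the exact flag spelling may differ across vLLM releases.

```python
from vllm import LLM, SamplingParams

# Placeholder MoE checkpoint; any model implementing get_expert_mapping()
# is the intended target of this change.
llm = LLM(
    model="some-org/some-moe-model",  # hypothetical model id
    quantization="bitsandbytes",      # routes loading through BitsAndBytesModelLoader
    load_format="bitsandbytes",       # quantize to NF4 on the fly while loading
    dtype="bfloat16",
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8))[0].outputs[0].text)
```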
@@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 # yapf: enable
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
+            elif (isinstance(module, FusedMoE)
+                  and hasattr(module.quant_method, "quant_config")):
+                if not hasattr(model, "get_expert_mapping"):
+                    raise AttributeError(
+                        f"MoE Model {type(model).__name__} does not support "
+                        "BitsAndBytes quantization yet. Ensure this model has "
+                        "'get_expert_mapping' method.")
+                # TODO: support FusedMoE with prequant and 8bit.
+                if self.pre_quant:
+                    raise ValueError(
+                        "Prequant BitsAndBytes models with FusedMoE is not "
+                        "supported yet.")
+                if self.load_8bit:
+                    raise ValueError(
+                        "BitsAndBytes 8bit quantization with FusedMoE is not "
+                        "supported yet.")
+                # Get the corresponding weight name using module name and
+                # get_expert_mapping.
+                expert_mapping = model.get_expert_mapping()
+                for exp in expert_mapping:
+                    weight_name = exp[1]
+                    rep_name = name.replace("experts",
+                                            "") + weight_name.removesuffix(".")
+                    self.target_modules.append(rep_name)
 
         assert (self.target_modules
-                ), "vllm currently does not support BNB quantization for"
+                ), "vLLM currently does not support BNB quantization for"
         f" {type(model).__name__}"
 
     def _classify_module_sharding(self, model: nn.Module):
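To make the name rewriting above concrete, here is a small, hypothetical example of how one expert-mapping entry is turned into a BNB target-module name. The tuple layout (param_name, weight_name, expert_id, shard_id) is assumed from how exp[1] and exp[-1] are used in this diff.

```python
# Hypothetical FusedMoE module name and expert-mapping entry.
name = "model.layers.0.mlp.experts"
exp = ("experts.w13_weight", "experts.0.gate_proj.", 0, "w1")

weight_name = exp[1]
rep_name = name.replace("experts", "") + weight_name.removesuffix(".")
assert rep_name == "model.layers.0.mlp.experts.0.gate_proj"
```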
@@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             # dimension (dim=-1)
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
+            elif isinstance(module, FusedMoE):
+                expert_mapping = model.get_expert_mapping()
+                for exp in expert_mapping:
+                    if exp[-1] == "w2":
+                        weight_name = exp[1]
+                        rep_name = name.replace(
+                            "experts", "") + weight_name.removesuffix(".")
+                        self.column_sharded_weights_modules.append(rep_name)
 
     def _verify_model_compatibility(self, model: nn.Module,
                                     model_config: ModelConfig) -> None:
@@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self._get_bnb_target_modules(model)
         self._classify_module_sharding(model)
 
-    def load_weights(self, model: nn.Module,
-                     model_config: ModelConfig) -> None:
-        self._verify_model_compatibility(model, model_config)
-        self._initialize_loader_state(model, model_config)
-
-        logger.info("Loading weights with BitsAndBytes quantization. "
-                    "May take a while ...")
-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(
-                model_config.model,
-                model_config.revision,
-            ))
-        weights_to_load = {name for name, _ in model.named_parameters()}
-        loaded_weights = model.load_weights(qweight_iterator)
-        # Some models may have weights loading tracker unimplemented.
-        if loaded_weights is not None:
-            weights_not_loaded = weights_to_load - loaded_weights
-            if weights_not_loaded:
-                raise ValueError("Following weights were not initialized from "
-                                 f"checkpoint: {weights_not_loaded}")
-        param_dict = dict(model.named_parameters())
+    def _dequantize_dq(self, quant_states: Any):
+        """
+        When BNB employs Double Quantization, we perform the dequantization of
+        these constants during weight loading rather than at inference time,
+        thereby avoiding this computational overhead during inference. This
+        comes at the cost of increased memory usage.
+        """
+        from bitsandbytes.functional import QuantState, dequantize_blockwise
+
+        def _dequantize_single_state(quant_state):
+            """Helper function to dequantize a single QuantState object."""
+            if not (isinstance(quant_state, QuantState)
+                    and quant_state.nested):
+                return
+            # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+            # Ensure float32 dtype
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
+
+        if isinstance(quant_states, dict):
+            for quant_state in quant_states.values():
+                _dequantize_single_state(quant_state)
+        else:
+            _dequantize_single_state(quant_states)
+        return quant_states
+
+    def _fuse_moe_quant_states(self, model: nn.Module,
+                               quant_states_dict: dict) -> dict:
+        """
+
+        This function consolidates individual expert quantization states into
+        fused representations for w13 and w2.
+        """
+        from bitsandbytes.functional import QuantState
+
+        if not hasattr(model, "get_expert_mapping"):
+            return dict()
+
+        expert_mapping = model.get_expert_mapping()
+        expert_qs_dict = {}
+        for name, module in model.named_modules():
+            if not isinstance(module, FusedMoE):
+                continue
+            w1_states_lst = []
+            w2_states_lst = []
+            w3_states_lst = []
+            for exp in expert_mapping:
+                shard_id = exp[-1]
+                if shard_id not in ("w1", "w2", "w3"):
+                    raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
+                                     f"got {shard_id}.")
+                layer_prefix = name.split("experts")[0]
+                weight_qual_name = layer_prefix + exp[1] + "weight"
+                quant_state = self._dequantize_dq(
+                    quant_states_dict[weight_qual_name])
+                if shard_id == "w1":
+                    w1_states_lst.append(quant_state)
+                elif shard_id == "w2":
+                    w2_states_lst.append(quant_state)
+                else:
+                    w3_states_lst.append(quant_state)
+                del quant_states_dict[weight_qual_name]
+            assert (len(w1_states_lst) == len(w2_states_lst) ==
+                    len(w3_states_lst))
+            w13_absmax_lst = []
+            w2_absmax_lst = []
+            w13_total_dim0 = 0
+            w2_total_dim0 = 0
+            for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
+                                           w3_states_lst):
+                assert w1_qs.shape == w3_qs.shape
+                assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
+                assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
+                # w1 and w3 are interleaved in storage
+                w13_absmax_lst.append(w1_qs.absmax)
+                w13_absmax_lst.append(w3_qs.absmax)
+                w2_absmax_lst.append(w2_qs.absmax)
+                w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
+                w2_total_dim0 += w2_qs.shape[0]
+
+            w13_absmax = torch.cat(w13_absmax_lst)
+            w2_absmax = torch.cat(w2_absmax_lst)
+            # Create fused quantization state for w13.
+            w13_qs = QuantState(
+                absmax=w13_absmax,
+                shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
+                code=w1_states_lst[0].code,
+                blocksize=w1_states_lst[0].blocksize,
+                quant_type="nf4",
+                dtype=w1_states_lst[0].dtype,
+            )
+            # Create fused quantization state for w2.
+            w2_qs = QuantState(
+                absmax=w2_absmax,
+                shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
+                code=w2_states_lst[0].code,
+                blocksize=w2_states_lst[0].blocksize,
+                quant_type="nf4",
+                dtype=w2_states_lst[0].dtype,
+            )
+            # The weight suffixes .w13_weight and .w2_weight are consistent
+            # with the param in BitsAndBytesMoEMethod.
+            w13_weight_name = name + ".w13_weight"
+            w2_weight_name = name + ".w2_weight"
+            expert_qs_dict[w13_weight_name] = w13_qs
+            expert_qs_dict[w2_weight_name] = w2_qs
+        return expert_qs_dict
+
     def _stack_quantization_states(
             self, model: nn.Module,
             quant_state_dict: dict) -> dict[str, dict[int, Any]]:
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
         # after the checks are updated to run on a new version
         from vllm.model_executor.models.utils import is_pp_missing_parameter
 
         param_dict = dict(model.named_parameters())
         for quant_param_name in quant_state_dict:
             if is_pp_missing_parameter(quant_param_name, model):
                 continue
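A toy sketch (made-up sizes, not a real model) of the absmax bookkeeping that _fuse_moe_quant_states performs: per-expert w1/w3 absmax vectors are appended in interleaved order and the fused shape grows along dim 0, so every expert's gate/up projection is addressed through a single w13 quant state.

```python
import torch

num_experts, out_dim, in_dim, blocksize = 2, 8, 16, 64  # toy sizes
w13_absmax_lst, w13_total_dim0 = [], 0

for _ in range(num_experts):
    n_blocks = out_dim * in_dim // blocksize   # one absmax per quantization block
    w1_absmax = torch.rand(n_blocks)           # stands in for w1_qs.absmax
    w3_absmax = torch.rand(n_blocks)           # stands in for w3_qs.absmax
    w13_absmax_lst += [w1_absmax, w3_absmax]   # w1 and w3 are interleaved in storage
    w13_total_dim0 += 2 * out_dim              # mirrors w1_qs.shape[0] + w3_qs.shape[0]

w13_absmax = torch.cat(w13_absmax_lst)
# The fused QuantState then records this absmax together with
# shape=(w13_total_dim0, in_dim), covering all experts at once.
print(w13_absmax.shape, (w13_total_dim0, in_dim))
```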
@@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             stacked_quant_state_dict[quant_param_name][shard_index] = (
                 quant_state_dict[non_stacked_param_name])
         return stacked_quant_state_dict
 
     def _bind_quant_states_to_params(self, model: nn.Module,
                                      stacked_quant_state_dict: dict) -> None:
         # save quant_states and offsets as the attributes of the parameters
         param_dict = dict(model.named_parameters())
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
                 # Dequantize double quantized values during weight loading.
-                dequantize_dq(quant_states)
+                self._dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
+                if not isinstance(quant_states, dict):
+                    continue
 
                 pack_ratio = getattr(param, "pack_factor", -1)
                 if pack_ratio == -1:
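One practical consequence of the binding above, as a sketch assuming the attribute names used in this diff: stacked linear parameters keep a dict of per-shard quant states, while the fused MoE parameters now carry a single QuantState, which is why the pack_factor bookkeeping below is skipped for them.

```python
import torch.nn as nn

def inspect_bnb_states(model: nn.Module) -> None:
    """Print which parameters received a per-shard vs. fused bnb_quant_state."""
    for pname, param in model.named_parameters():
        qs = getattr(param, "bnb_quant_state", None)
        if qs is None:
            continue
        if isinstance(qs, dict):
            # e.g. qkv_proj / gate_up_proj: one QuantState per stacked shard
            print(pname, "shard indices:", sorted(qs.keys()))
        else:
            # e.g. experts.w13_weight / experts.w2_weight: one fused QuantState
            print(pname, "fused shape:", qs.shape)
```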
@@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if self.load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
 
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+
+        self._verify_model_compatibility(model, model_config)
+        self._initialize_loader_state(model, model_config)
+
+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(
+                model_config.model,
+                model_config.revision,
+            ))
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
+        expert_quant_state_dict = self._fuse_moe_quant_states(
+            model, quant_state_dict)
+
+        stacked_quant_state_dict = self._stack_quantization_states(
+            model, quant_state_dict)
+
+        stacked_quant_state_dict = {
+            **expert_quant_state_dict,
+            **stacked_quant_state_dict
+        }
+        self._bind_quant_states_to_params(model, stacked_quant_state_dict)
+        torch.cuda.empty_cache()
+
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
 
-def dequantize_dq(quant_states: dict) -> None:
-    """
-    When BNB employs Double Quantization, we perform the dequantization of
-    these constants during weight loading rather than at inference time,
-    thereby avoiding this computational overhead during inference. This comes
-    at the cost of increased memory usage.
-    """
-    from bitsandbytes.functional import QuantState, dequantize_blockwise
-    for _, quant_state in quant_states.items():
-        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
-        if isinstance(quant_state, QuantState) and quant_state.nested:
-            absmax = dequantize_blockwise(quant_state.absmax,
-                                          quant_state.state2)
-            absmax += quant_state.offset
-            if absmax.dtype != torch.float32:
-                absmax = absmax.float()
-            quant_state.absmax = absmax
-            quant_state.nested = False
-            quant_state.offset = None
-            quant_state.state2 = None
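The module-level dequantize_dq removed above (and the _dequantize_dq method that replaces it) undo BitsAndBytes double quantization once at load time. A rough toy illustration of what a nested quant state stores follows; a single global scale is used here for brevity, whereas bitsandbytes quantizes absmax blockwise with an offset.

```python
import torch

# Per-block scales of an NF4 weight (values are made up).
absmax = torch.rand(1024) * 3.0

# Double quantization: store absmax itself in 8 bits relative to an offset.
offset = absmax.mean()
scale2 = (absmax - offset).abs().max() / 127
q_absmax = torch.round((absmax - offset) / scale2).to(torch.int8)  # kept in the nested state

# What _dequantize_dq materializes once, so inference never repeats this step.
flat_absmax = q_absmax.float() * scale2 + offset
print(torch.allclose(flat_absmax, absmax, atol=scale2.item()))
```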