[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-07-11 16:01:13 +08:00
Committed by: GitHub
Parent: 762be26a8e
Commit: 8020e98c9f
8 changed files with 561 additions and 88 deletions


@@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
elif (isinstance(module, FusedMoE)
and hasattr(module.quant_method, "quant_config")):
                if not hasattr(model, "get_expert_mapping"):
                    raise AttributeError(
                        f"MoE model {type(model).__name__} does not support "
                        "BitsAndBytes quantization yet. Ensure this model "
                        "implements the 'get_expert_mapping' method.")
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant:
raise ValueError(
"Prequant BitsAndBytes models with FusedMoE is not "
"supported yet.")
if self.load_8bit:
raise ValueError(
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet.")
# Get the corresponding weight name using module name and
# get_expert_mapping.
expert_mapping = model.get_expert_mapping()
for exp in expert_mapping:
weight_name = exp[1]
rep_name = name.replace("experts",
"") + weight_name.removesuffix(".")
self.target_modules.append(rep_name)
        assert self.target_modules, (
            "vLLM currently does not support BNB quantization for"
            f" {type(model).__name__}")

def _classify_module_sharding(self, model: nn.Module):
@@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# dimension (dim=-1)
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
expert_mapping = model.get_expert_mapping()
for exp in expert_mapping:
if exp[-1] == "w2":
weight_name = exp[1]
rep_name = name.replace(
"experts", "") + weight_name.removesuffix(".")
self.column_sharded_weights_modules.append(rep_name)
def _verify_model_compatibility(self, model: nn.Module,
model_config: ModelConfig) -> None:
@@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self._get_bnb_target_modules(model)
self._classify_module_sharding(model)
    def _dequantize_dq(self, quant_states: Any):
        """
        When BNB employs Double Quantization, we perform the dequantization
        of these constants during weight loading rather than at inference
        time, thereby avoiding this computational overhead during inference.
        This comes at the cost of increased memory usage.
        """
        from bitsandbytes.functional import QuantState, dequantize_blockwise

        def _dequantize_single_state(quant_state):
            """Helper function to dequantize a single QuantState object."""
            if not (isinstance(quant_state, QuantState)
                    and quant_state.nested):
                return
            # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
            absmax = dequantize_blockwise(quant_state.absmax,
                                          quant_state.state2)
            absmax += quant_state.offset
            # Ensure float32 dtype
            if absmax.dtype != torch.float32:
                absmax = absmax.float()
            quant_state.absmax = absmax
            quant_state.nested = False
            quant_state.offset = None
            quant_state.state2 = None

        if isinstance(quant_states, dict):
            for quant_state in quant_states.values():
                _dequantize_single_state(quant_state)
        else:
            _dequantize_single_state(quant_states)
        return quant_states

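To make the Double Quantization bookkeeping concrete, here is a PyTorch-only toy sketch of what a nested QuantState stores and what the dequantization above reconstructs; the codebook, shapes, and single-scale block are simplifying assumptions, not the bitsandbytes internals:

import torch

# Toy stand-in for a nested (double-quantized) state: the per-block
# absmax values are themselves stored as uint8 codes into a codebook,
# with their own scale (state2) and a mean offset.
code = torch.linspace(-1.0, 1.0, 256)      # assumed 8-bit codebook
q_absmax = torch.randint(0, 256, (64,))    # quantized absmax codes
state2_scale = torch.rand(())              # assumed single absmax-block scale
offset = torch.tensor(0.5)                 # mean removed at quantization time

# Conceptual equivalent of
#   dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax = code[q_absmax] * state2_scale
absmax = (absmax + offset).float()         # re-add offset, ensure float32
# After this, the state behaves like a non-nested QuantState: inference
# can use absmax directly instead of dequantizing it on every forward.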
def _fuse_moe_quant_states(self, model: nn.Module,
quant_states_dict: dict) -> dict:
"""
This function consolidates individual expert quantization states into
fused representations for w13 and w2.
"""
from bitsandbytes.functional import QuantState
if not hasattr(model, "get_expert_mapping"):
return dict()
expert_mapping = model.get_expert_mapping()
expert_qs_dict = {}
for name, module in model.named_modules():
if not isinstance(module, FusedMoE):
continue
w1_states_lst = []
w2_states_lst = []
w3_states_lst = []
for exp in expert_mapping:
shard_id = exp[-1]
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
layer_prefix = name.split("experts")[0]
weight_qual_name = layer_prefix + exp[1] + "weight"
quant_state = self._dequantize_dq(
quant_states_dict[weight_qual_name])
if shard_id == "w1":
w1_states_lst.append(quant_state)
elif shard_id == "w2":
w2_states_lst.append(quant_state)
else:
w3_states_lst.append(quant_state)
del quant_states_dict[weight_qual_name]
assert (len(w1_states_lst) == len(w2_states_lst) ==
len(w3_states_lst))
w13_absmax_lst = []
w2_absmax_lst = []
w13_total_dim0 = 0
w2_total_dim0 = 0
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
w3_states_lst):
assert w1_qs.shape == w3_qs.shape
assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
# w1 and w3 are interleaved in storage
w13_absmax_lst.append(w1_qs.absmax)
w13_absmax_lst.append(w3_qs.absmax)
w2_absmax_lst.append(w2_qs.absmax)
w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
w2_total_dim0 += w2_qs.shape[0]
w13_absmax = torch.cat(w13_absmax_lst)
w2_absmax = torch.cat(w2_absmax_lst)
# Create fused quantization state for w13.
w13_qs = QuantState(
absmax=w13_absmax,
shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
code=w1_states_lst[0].code,
blocksize=w1_states_lst[0].blocksize,
quant_type="nf4",
dtype=w1_states_lst[0].dtype,
)
# Create fused quantization state for w2.
w2_qs = QuantState(
absmax=w2_absmax,
shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
code=w2_states_lst[0].code,
blocksize=w2_states_lst[0].blocksize,
quant_type="nf4",
dtype=w2_states_lst[0].dtype,
)
            # The weight suffixes .w13_weight and .w2_weight match the
            # parameter names used in BitsAndBytesMoEMethod.
w13_weight_name = name + ".w13_weight"
w2_weight_name = name + ".w2_weight"
expert_qs_dict[w13_weight_name] = w13_qs
expert_qs_dict[w2_weight_name] = w2_qs
return expert_qs_dict
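As a toy illustration of the fusion above (all sizes invented): with two experts whose w1/w3 shards are (N, K) each, the fused w13 state spans dim0 of 2 * 2 * N, and its absmax is the per-expert absmax tensors concatenated in w1, w3, w1, w3 order, matching the interleaved storage noted in the loop:

import torch

# Invented sizes: 2 experts, w1/w3 shards of shape (N, K), NF4 blocksize 64.
N, K, blocksize = 8, 16, 64
w1_absmax = [torch.rand(N * K // blocksize) for _ in range(2)]
w3_absmax = [torch.rand(N * K // blocksize) for _ in range(2)]

w13_absmax_lst = []
w13_total_dim0 = 0
for a1, a3 in zip(w1_absmax, w3_absmax):
    w13_absmax_lst += [a1, a3]    # w1 and w3 interleaved, expert by expert
    w13_total_dim0 += N + N

w13_absmax = torch.cat(w13_absmax_lst)
w13_shape = (w13_total_dim0, K)   # (2 experts * 2 shards * N, K) = (32, 16)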
def _stack_quantization_states(
self, model: nn.Module,
quant_state_dict: dict) -> dict[str, dict[int, Any]]:
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
param_dict = dict(model.named_parameters())
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
continue
@@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
stacked_quant_state_dict[quant_param_name][shard_index] = (
quant_state_dict[non_stacked_param_name])
return stacked_quant_state_dict
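The returned structure maps each stacked vLLM parameter name to a shard-indexed dict of quant states, roughly like this (parameter names are illustrative):

# Sketch of the output of _stack_quantization_states: checkpoint params
# q_proj/k_proj/v_proj fold into one qkv_proj parameter, with the shard
# index recording which slice each QuantState describes.
stacked_quant_state_dict = {
    "model.layers.0.self_attn.qkv_proj.weight": {
        0: "QuantState for q_proj",  # placeholders for real QuantState objects
        1: "QuantState for k_proj",
        2: "QuantState for v_proj",
    },
}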
def _bind_quant_states_to_params(self, model: nn.Module,
stacked_quant_state_dict: dict) -> None:
# save quant_states and offsets as the attributes of the parameters
param_dict = dict(model.named_parameters())
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
                # Dequantize double-quantized values during weight loading.
                self._dequantize_dq(quant_states)
set_weight_attrs(param, {"bnb_quant_state": quant_states})
if not isinstance(quant_states, dict):
continue
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
@@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)
logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
        # Some models may leave the weight-loading tracker unimplemented.
if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
expert_quant_state_dict = self._fuse_moe_quant_states(
model, quant_state_dict)
stacked_quant_state_dict = self._stack_quantization_states(
model, quant_state_dict)
stacked_quant_state_dict = {
**expert_quant_state_dict,
**stacked_quant_state_dict
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
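End to end, this loader is exercised when serving a model with in-flight BitsAndBytes quantization. A minimal usage sketch (the model name is only an example; per the checks above, FusedMoE support covers 4-bit in-flight quantization, not prequantized or 8-bit BNB, and older vLLM versions may also require load_format="bitsandbytes"):

from vllm import LLM

# In-flight BNB quantization: the full-precision checkpoint is quantized
# to NF4 while weights load, now including FusedMoE expert weights.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",  # example MoE model
    quantization="bitsandbytes",
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)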