Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions


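The mechanical substance of the change, repeated across the hunks below: yapf's hanging indentation (continuation lines aligned under the opening parenthesis) is replaced by ruff format's Black-style layout (4-space indent, one item per line, trailing comma), and import sorting moves from isort to ruff, which also lets the old "# yapf conflicts with isort for this block" / "# yapf: enable" guard comments be dropped. A minimal before/after illustration, mirroring the first import hunk rather than quoting it exactly:

# Before (yapf + isort): continuation aligned under the opening parenthesis
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)

# After (ruff format): 4-space indent, one name per line, trailing comma
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)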
@@ -18,27 +18,37 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
# yapf: enable
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (ParamMapping,
set_default_torch_dtype)
from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import (get_moe_expert_mapping,
get_packed_modules_mapping,
set_weight_attrs)
from vllm.model_executor.utils import (
get_moe_expert_mapping,
get_packed_modules_mapping,
set_weight_attrs,
)
from vllm.platforms import current_platform
# yapf conflicts with isort for this block
@@ -48,8 +58,7 @@ logger = init_logger(__name__)
def is_moe_model(model: torch.nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers."""
return bool(any(
isinstance(module, FusedMoE) for module in model.modules()))
return bool(any(isinstance(module, FusedMoE) for module in model.modules()))
class BitsAndBytesModelLoader(BaseModelLoader):
@@ -92,8 +101,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(
os.path.join(model_name_or_path, pattern))
weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
if weight_files:
return model_name_or_path, weight_files, pattern
else:
@@ -109,20 +117,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return hf_folder, glob.glob(
os.path.join(hf_folder, pattern)), pattern
return (
hf_folder,
glob.glob(os.path.join(hf_folder, pattern)),
pattern,
)
raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`")
raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> tuple[list[str], bool]:
def _prepare_weights(
self, model_name_or_path: str, revision: Optional[str]
) -> tuple[list[str], bool]:
"""Prepare weight files for the model."""
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision)
model_name_or_path, allowed_patterns, revision
)
use_safetensors = matched_pattern == "*.safetensors"
is_local = os.path.isdir(model_name_or_path)
@@ -141,25 +153,27 @@ class BitsAndBytesModelLoader(BaseModelLoader):
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
def _maybe_pool_model(module_name: str):
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
if self.is_pool_model and self.target_modules[0]. \
startswith("model.") and not module_name.startswith(
"model."):
if (
self.is_pool_model
and self.target_modules[0].startswith("model.")
and not module_name.startswith("model.")
):
return "model." + module_name
return module_name
@@ -187,8 +201,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self,
model_name_or_path: str,
revision: Optional[str],
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
Any]]:
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
@@ -196,37 +209,41 @@ class BitsAndBytesModelLoader(BaseModelLoader):
try:
import bitsandbytes
if version.parse(
bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1.")
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision)
model_name_or_path, revision
)
quant_state_dict: dict[str, Any] = {}
if self.pre_quant:
if self.load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
else:
return self._quantized_4bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
return self._unquantized_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix)
for suffix in quantized_suffix)
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
@@ -239,12 +256,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
def _quantized_8bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue
@@ -253,9 +271,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_state_dict[weight_key] = weight_tensor
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue
@@ -266,18 +284,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
else:
yield org_weight_name, weight_tensor
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
def _quantized_4bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import QuantState
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
temp_state_dict = {}
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
@@ -289,98 +307,111 @@ class BitsAndBytesModelLoader(BaseModelLoader):
temp_state_dict[mapped_weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
temp_state_dict: dict) -> QuantState:
def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(quant_state,
device=current_platform.device_type)
return QuantState.from_dict(
quant_state, device=current_platform.device_type
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
in temp_state_dict) or (
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
in temp_state_dict):
quant_state = _parse_quant_state(mapped_weight_name,
temp_state_dict)
if (
f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
) or (
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
):
quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
def _unquantized_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import quantize_4bit
global_tp_size = get_tensor_model_parallel_world_size()
global_tp_rank = get_tensor_model_parallel_rank()
check_match = (lambda weight_name, module_name: weight_name.
removesuffix(".weight") == module_name)
check_match = (
lambda weight_name, module_name: weight_name.removesuffix(".weight")
== module_name
)
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
# override tp_size and tp_rank if the module has disabled TP
if any(tp_disabled_module in mapped_weight_name
for tp_disabled_module in self.tp_disabled_modules):
if any(
tp_disabled_module in mapped_weight_name
for tp_disabled_module in self.tp_disabled_modules
):
tp_size = 1
tp_rank = 0
else:
tp_size = global_tp_size
tp_rank = global_tp_rank
if any(target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
if any(
target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
# Without sharding
if any(
check_match(mapped_weight_name, module)
for module in self.unsharded_weights_modules):
check_match(mapped_weight_name, module)
for module in self.unsharded_weights_modules
):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(
check_match(mapped_weight_name, module)
for module in self.column_sharded_weights_modules):
check_match(mapped_weight_name, module)
for module in self.column_sharded_weights_modules
):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[...,
start_index:end_index]
weight_sub_tensor = weight_tensor[..., start_index:end_index]
# Weights have fused on disk. In this case, we assume that the
# weight and module use same name.
elif any(
check_match(mapped_weight_name, module)
for module in self.maybe_fused_weights_modules):
check_match(mapped_weight_name, module)
for module in self.maybe_fused_weights_modules
):
# special case for fused weights
# get the size of each shard weight tensor
total_shard_sizes = next(
(sizes for module, sizes in
self.maybe_fused_weights_modules.items()
if check_match(mapped_weight_name, module)))
(
sizes
for module, sizes in self.maybe_fused_weights_modules.items()
if check_match(mapped_weight_name, module)
)
)
total_size = weight_tensor.size(0)
assert total_size == sum(total_shard_sizes)
# get the start/end index of each shard weight tensor
total_start_index = list(
itertools.accumulate([0] + total_shard_sizes))[:-1]
shard_weights_index = [(
idx + size // tp_size * tp_rank,
idx + size // tp_size * (tp_rank + 1),
) for idx, size in zip(total_start_index,
total_shard_sizes)]
itertools.accumulate([0] + total_shard_sizes)
)[:-1]
shard_weights_index = [
(
idx + size // tp_size * tp_rank,
idx + size // tp_size * (tp_rank + 1),
)
for idx, size in zip(total_start_index, total_shard_sizes)
]
# slice and reorder the weight tensor
weight_tensor = [
weight_tensor[start_index:end_index, ...]
@@ -392,15 +423,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index,
...]
weight_sub_tensor = weight_tensor[start_index:end_index, ...]
# bitsandbytes requires data in GPU
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.to(
device=current_platform.device_type)
device=current_platform.device_type
)
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
@@ -421,12 +452,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _get_bnb_target_modules(self, model: nn.Module) -> None:
"""
Identify and collect all modules that support BitsAndBytes
quantization.
"""
for name, module in model.named_modules():
if (isinstance(module, LinearBase)
and hasattr(module.quant_method, "quant_config")):
if isinstance(module, LinearBase) and hasattr(
module.quant_method, "quant_config"
):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
@@ -442,45 +474,48 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if module.disable_tp:
self.tp_disabled_modules.append(name)
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"):
module.quant_method, "quant_config"
):
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant and self.load_8bit:
raise ValueError(
"Prequant BitsAndBytes 8bit models with FusedMoE "
"is not supported yet.")
"is not supported yet."
)
# Get the corresponding weight name using module name and
# expert_params_mapping.
for exp in self.expert_params_mapping:
weight_name = exp[1]
rep_name = name.replace("experts",
"") + weight_name.removesuffix(".")
rep_name = name.replace("experts", "") + weight_name.removesuffix(
"."
)
self.target_modules.append(rep_name)
assert (self.target_modules
), "vLLM currently does not support BNB quantization for"
assert self.target_modules, (
"vLLM currently does not support BNB quantization for"
)
f" {type(model).__name__}"
def _classify_module_sharding(self, model: nn.Module):
"""
Categorize modules based on their weight sharding requirements
for tensor parallelism.
"""
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
if isinstance(module, (ReplicatedLinear, )):
if isinstance(module, (ReplicatedLinear,)):
self.unsharded_weights_modules.append(name)
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
# fused weights on disk. We need to use the output sizes of these
# modules to shard the weights correctly.
elif isinstance(module,
(QKVParallelLinear, MergedColumnParallelLinear)):
elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)):
self.maybe_fused_weights_modules[name] = module.output_sizes
# In TP, these weights are partitioned along the column
# dimension (dim=-1)
elif isinstance(module, (RowParallelLinear, )):
elif isinstance(module, (RowParallelLinear,)):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
expert_mapping = self.expert_params_mapping
@@ -488,48 +523,53 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if exp[-1] == "w2":
weight_name = exp[1]
rep_name = name.replace(
"experts", "") + weight_name.removesuffix(".")
"experts", ""
) + weight_name.removesuffix(".")
self.column_sharded_weights_modules.append(rep_name)
def _verify_model_compatibility(self, model: nn.Module,
model_config: ModelConfig) -> None:
def _verify_model_compatibility(
self, model: nn.Module, model_config: ModelConfig
) -> None:
"""
Verify that the model is compatible with BitsAndBytes quantization.
"""
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.")
f" {type(model).__name__}."
)
if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found.")
"quantization yet. No 'packed_modules_mapping' found."
)
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
quant_config = getattr(model_config.hf_config, "quantization_config", None)
if quant_config is not None:
quant_method = quant_config.get("quant_method")
if quant_method == "bitsandbytes":
self.pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
"quantization")
f"BitsAndBytes loader does not support {quant_method} quantization"
)
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism.")
"supported. Please try with pipeline parallelism."
)
if self.pre_quant:
self.load_8bit = quant_config.get("load_in_8bit", False)
def _initialize_loader_state(self, model: nn.Module,
model_config: ModelConfig) -> None:
def _initialize_loader_state(
self, model: nn.Module, model_config: ModelConfig
) -> None:
"""
Initialize the loader's internal state based on the model and
configuration.
"""
self.is_pool_model = is_pooling_model(model)
@@ -541,7 +581,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise AttributeError(
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method.")
"'get_expert_mapping' method."
)
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
@@ -552,22 +593,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _dequantize_dq(self, quant_states: Any):
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This
comes at the cost of increased memory usage.
"""
from bitsandbytes.functional import QuantState, dequantize_blockwise
def _dequantize_single_state(quant_state):
"""Helper function to dequantize a single QuantState object."""
if not (isinstance(quant_state, QuantState)
and quant_state.nested):
if not (isinstance(quant_state, QuantState) and quant_state.nested):
return
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
absmax = dequantize_blockwise(quant_state.absmax,
quant_state.state2)
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
# Ensure float32 dtype
@@ -586,10 +625,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
_dequantize_single_state(quant_states)
return quant_states
def _fuse_moe_quant_states(self, model: nn.Module,
quant_states_dict: dict) -> dict:
def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict:
"""
This function consolidates individual expert quantization states into
fused representations for w13 and w2.
"""
@@ -609,12 +647,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for exp in expert_mapping:
shard_id = exp[-1]
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
raise ValueError(
f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
)
layer_prefix = name.split("experts")[0]
weight_qual_name = layer_prefix + exp[1] + "weight"
quant_state = self._dequantize_dq(
quant_states_dict[weight_qual_name])
quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name])
if shard_id == "w1":
w1_states_lst.append(quant_state)
elif shard_id == "w2":
@@ -622,14 +660,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
else:
w3_states_lst.append(quant_state)
del quant_states_dict[weight_qual_name]
assert (len(w1_states_lst) == len(w2_states_lst) ==
len(w3_states_lst))
assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst)
w13_absmax_lst = []
w2_absmax_lst = []
w13_total_dim0 = 0
w2_total_dim0 = 0
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
w3_states_lst):
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst):
assert w1_qs.shape == w3_qs.shape
assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
@@ -669,12 +705,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return expert_qs_dict
def _stack_quantization_states(
self, model: nn.Module,
quant_state_dict: dict) -> dict[str, dict[int, Any]]:
self, model: nn.Module, quant_state_dict: dict
) -> dict[str, dict[int, Any]]:
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
param_dict = dict(model.named_parameters())
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
@@ -684,23 +721,23 @@ class BitsAndBytesModelLoader(BaseModelLoader):
shard_index = 0
for shard_name, (
weight_name,
index,
) in self.modules_mapping.inverse_packed_mapping.items():
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
shard_pos = quant_param_name.find(shard_name)
can_correct_rename = (shard_pos
> 0) and (quant_param_name[shard_pos - 1]
== ".")
can_correct_rename = (shard_pos > 0) and (
quant_param_name[shard_pos - 1] == "."
)
# If the quant_param_name is packed, it won't occur in the
# param_dict before renaming.
new_quant_param_name = quant_param_name.replace(
shard_name, weight_name)
need_rename = (quant_param_name not in param_dict) \
and (new_quant_param_name in param_dict)
new_quant_param_name = quant_param_name.replace(shard_name, weight_name)
need_rename = (quant_param_name not in param_dict) and (
new_quant_param_name in param_dict
)
if can_correct_rename and need_rename:
shard_index = index
quant_param_name = new_quant_param_name
@@ -714,12 +751,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if quant_param_name not in stacked_quant_state_dict:
stacked_quant_state_dict[quant_param_name] = {}
stacked_quant_state_dict[quant_param_name][shard_index] = (
quant_state_dict[non_stacked_param_name])
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
non_stacked_param_name
]
return stacked_quant_state_dict
def _bind_quant_states_to_params(self, model: nn.Module,
stacked_quant_state_dict: dict) -> None:
def _bind_quant_states_to_params(
self, model: nn.Module, stacked_quant_state_dict: dict
) -> None:
# save quant_states and offsets as the attributes of the parameters
param_dict = dict(model.named_parameters())
for param_name, param in param_dict.items():
@@ -733,13 +772,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
raise ValueError(
f"pack_factor not set for parameter {param_name}.")
raise ValueError(f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in quant_states.items():
num_elements[seq] = (math.prod(quant_state.shape) //
pack_ratio)
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Make torch infer_schema happy
@@ -748,38 +785,39 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
param, {"matmul_state": [None] * len(quant_states)}
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)
logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
))
logger.info(
"Loading weights with BitsAndBytes quantization. May take a while ..."
)
qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may have weights loading tracker unimplemented.
if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
expert_quant_state_dict = self._fuse_moe_quant_states(
model, quant_state_dict)
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict)
stacked_quant_state_dict = self._stack_quantization_states(
model, quant_state_dict)
model, quant_state_dict
)
stacked_quant_state_dict = {
**expert_quant_state_dict,
**stacked_quant_state_dict
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
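
Why some call sites above collapse onto a single line while others stay split one argument per line: ruff format follows Black's "magic trailing comma" rule, so an argument list that fits within the line limit is joined unless it ends with a trailing comma, in which case it stays exploded. A small sketch with hypothetical names, not an excerpt from the diff:

# No trailing comma: ruff format joins the call when it fits on one line.
weights = prepare_weights(model_name_or_path, revision)

# Magic trailing comma: ruff format keeps one argument per line.
weights = prepare_weights(
    model_name_or_path,
    revision,
)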