support bitsandbytes quantization with more models (#9148)

This commit is contained in:
chenqianfzh
2024-10-08 18:52:19 -07:00
committed by GitHub
parent 9ba0bd6aa6
commit 2f4117c38e
10 changed files with 165 additions and 28 deletions

View File

@@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
param_data = param.data
if output_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
@@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
@@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
"v":
((self.total_num_heads + self.total_num_kv_heads) *
self.head_size,
self.total_num_kv_heads * self.head_size),
"total":
((self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_size, 0)
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)