support bitsandbytes quantization with more models (#9148)
This commit is contained in:
@@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
|
||||
param_data = param.data
|
||||
if output_dim is not None:
|
||||
# bitsandbytes already loads only the specific portion of the weights,
|
||||
# so there is no need to narrow here.
|
||||
if output_dim is not None and not use_bitsandbytes_4bit:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
@@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
("v", (self.total_num_heads + self.total_num_kv_heads) *
|
||||
self.head_size, self.total_num_kv_heads * self.head_size),
|
||||
]
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
|
||||
False)
|
||||
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantized Weights.
|
||||
@@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
orig_qkv_offsets = {
|
||||
"q": (0, self.total_num_heads * self.head_size),
|
||||
"k": (self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size),
|
||||
"v":
|
||||
((self.total_num_heads + self.total_num_kv_heads) *
|
||||
self.head_size,
|
||||
self.total_num_kv_heads * self.head_size),
|
||||
"total":
|
||||
((self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_size, 0)
|
||||
}
|
||||
|
||||
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||
param, orig_qkv_offsets, shard_id)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
|
||||
Reference in New Issue
Block a user