Enable bnb for multiple indices weight (#35838)
Signed-off-by: xjx <493337577@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -744,10 +744,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
)
|
||||
current_shard_offset = 0
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
|
||||
if (
|
||||
use_bitsandbytes_4bit
|
||||
and isinstance(loaded_shard_id, tuple)
|
||||
and self.tp_size > 1
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Shard id with multiple indices is not supported "
|
||||
"for BNB quantization yet."
|
||||
"for BNB quantization with TP yet."
|
||||
)
|
||||
shard_offsets: list[tuple[int, int, int]] = []
|
||||
for i, output_size in enumerate(output_sizes):
|
||||
@@ -815,9 +819,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
||||
|
||||
index = list(itertools.accumulate([0] + self.output_sizes))
|
||||
orig_offsets = {
|
||||
str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
|
||||
}
|
||||
orig_offsets["total"] = (self.output_size, 0)
|
||||
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||
param, orig_offsets, str(loaded_shard_id)
|
||||
)
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
if not is_sharded_weight:
|
||||
|
||||
Reference in New Issue
Block a user