Enable BNB quantization for weights with multiple shard indices (#35838)
Signed-off-by: xjx <493337577@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -744,10 +744,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
)
|
)
|
||||||
current_shard_offset = 0
|
current_shard_offset = 0
|
||||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||||
if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
|
if (
|
||||||
|
use_bitsandbytes_4bit
|
||||||
|
and isinstance(loaded_shard_id, tuple)
|
||||||
|
and self.tp_size > 1
|
||||||
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Shard id with multiple indices is not supported "
|
"Shard id with multiple indices is not supported "
|
||||||
"for BNB quantization yet."
|
"for BNB quantization with TP yet."
|
||||||
)
|
)
|
||||||
shard_offsets: list[tuple[int, int, int]] = []
|
shard_offsets: list[tuple[int, int, int]] = []
|
||||||
for i, output_size in enumerate(output_sizes):
|
for i, output_size in enumerate(output_sizes):
|
||||||
@@ -815,9 +819,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||||
|
|
||||||
if use_bitsandbytes_4bit:
|
if use_bitsandbytes_4bit:
|
||||||
shard_size = loaded_weight.shape[output_dim]
|
index = list(itertools.accumulate([0] + self.output_sizes))
|
||||||
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
orig_offsets = {
|
||||||
|
str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
|
||||||
|
}
|
||||||
|
orig_offsets["total"] = (self.output_size, 0)
|
||||||
|
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
|
||||||
|
param, orig_offsets, str(loaded_shard_id)
|
||||||
|
)
|
||||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||||
start_idx = self.tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
if not is_sharded_weight:
|
if not is_sharded_weight:
|
||||||
|
|||||||
Reference in New Issue
Block a user