Enable bnb for multiple indices weight (#35838)

Signed-off-by: xjx <493337577@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
xjx
2026-03-04 09:46:47 +08:00
committed by GitHub
parent f7da9cdffc
commit 9a9d442464

View File

@@ -744,10 +744,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
)
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
if (
use_bitsandbytes_4bit
and isinstance(loaded_shard_id, tuple)
and self.tp_size > 1
):
raise NotImplementedError(
"Shard id with multiple indices is not supported "
"for BNB quantization yet."
"for BNB quantization with TP yet."
)
shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(output_sizes):
@@ -815,9 +819,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_offsets, str(loaded_shard_id)
)
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = self.tp_rank * shard_size
if not is_sharded_weight: