diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f0d06e179..bfcdaa4c0 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -744,10 +744,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ) current_shard_offset = 0 use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) - if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple): + if ( + use_bitsandbytes_4bit + and isinstance(loaded_shard_id, tuple) + and self.tp_size > 1 + ): raise NotImplementedError( "Shard id with multiple indices is not supported " - "for BNB quantization yet." + "for BNB quantization with TP yet." ) shard_offsets: list[tuple[int, int, int]] = [] for i, output_size in enumerate(output_sizes): @@ -815,9 +819,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit if use_bitsandbytes_4bit: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id - + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(loaded_shard_id) + ) param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = self.tp_rank * shard_size if not is_sharded_weight: