diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 6db3907ff..5fc9fa073 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -731,16 +731,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         loaded_shard_id: tuple[int, ...] | int | None = None,
     ):
         self.validate_shard_id(loaded_shard_id)
-        # FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
-        if isinstance(loaded_shard_id, tuple):
-            raise NotImplementedError(
-                "Shard id with multiple indices is not supported in weight_loader, "
-                "please use weight_loader_v2 instead."
-            )
         # Special case for GGUF
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if isinstance(loaded_shard_id, tuple) and (
+            is_gguf_weight or is_gguf_weight_type
+        ):
+            raise NotImplementedError(
+                "Shard id with multiple indices is not supported for GGUF."
+            )
         if is_gguf_weight_type:
             if loaded_shard_id is not None:
                 param.data[loaded_shard_id].copy_(loaded_weight)
@@ -768,7 +768,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         # Special case for per-tensor scale to load scalar into fused array.
         needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
-        if loaded_shard_id is None:
+        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
             # Loaded weight is already fused on disk (mlp).
             # (e.g., Phi-3's gate_up_proj).
             if output_dim is None:
@@ -780,10 +780,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
+
+            output_sizes = (
+                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
+                if loaded_shard_id is not None
+                else self.output_sizes
+            )
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
+                raise NotImplementedError(
+                    "Shard id with multiple indices is not supported "
+                    "for BNB quantization yet."
+                )
             shard_offsets: list[tuple[int, int, int]] = []
-            for i, output_size in enumerate(self.output_sizes):
+            for i, output_size in enumerate(output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
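
For reference, a minimal standalone sketch of the core of the change above: a tuple shard id narrows the fused output sizes to a contiguous range before the per-shard offsets are accumulated. The sizes and helper names below are made up for illustration and are not part of the patch.

# Standalone sketch (hypothetical values and helpers, not vLLM code): shows how a
# tuple shard id selects a contiguous slice of the fused output sizes, and how the
# per-shard offsets are then accumulated, mirroring the loop in the patch.
output_sizes = [1024, 1024, 2048]  # made-up sizes of the fused sub-weights

def select_output_sizes(loaded_shard_id):
    # None -> load every shard (weight already fused on disk);
    # tuple -> load only the contiguous range of shards it names.
    if loaded_shard_id is not None:
        return output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
    return output_sizes

def compute_shard_offsets(sizes):
    offsets, current = [], 0
    for i, size in enumerate(sizes):
        offsets.append((i, current, size))
        current += size
    return offsets

print(compute_shard_offsets(select_output_sizes((0, 1))))  # [(0, 0, 1024), (1, 1024, 1024)]
print(compute_shard_offsets(select_output_sizes(None)))    # all three shards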