diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 07dc2cb7f..975fedabd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -910,7 +910,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): self.validate_shard_id(loaded_shard_id) if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): if isinstance(param, PerTensorScaleParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) + if isinstance(loaded_shard_id, tuple): + for idx in loaded_shard_id: + param.load_merged_column_weight( + loaded_weight=loaded_weight, shard_id=idx + ) + else: + param.load_merged_column_weight( + loaded_weight=loaded_weight, shard_id=0 + ) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight)