From a6c137521cf7218cc2da5f56aa3e68ad96aa76b1 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Wed, 25 Feb 2026 14:12:28 +0800
Subject: [PATCH] [Misc] Add shard_id validation for MergedColumnLinear (#35055)

Signed-off-by: Isotr0py
---
 vllm/model_executor/layers/linear.py | 74 +++++++++++++++++++++++++---
 1 file changed, 67 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 6467c7d13..6db3907ff 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -66,15 +66,23 @@ WEIGHT_LOADER_V2_SUPPORTED = [
 ]
 
 
-def adjust_marlin_shard(param, shard_size, shard_offset):
-    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+def adjust_marlin_shard(
+    param: Parameter,
+    shard_size: int,
+    shard_offset: int,
+) -> tuple[int, int]:
+    marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
     if marlin_tile_size is None:
         return shard_size, shard_offset
 
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
+def adjust_block_scale_shard(
+    weight_block_size: tuple[int, ...] | None,
+    shard_size: int,
+    shard_offset: int,
+) -> tuple[int, int]:
     assert weight_block_size is not None
     block_n = weight_block_size[0]
     shard_offset = (shard_offset + block_n - 1) // block_n
@@ -83,7 +91,9 @@ def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
+    param: Parameter,
+    shard_offsets: dict[str, tuple[int, int]],
+    loaded_shard_id: str,
 ) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
@@ -97,7 +107,11 @@ def adjust_bitsandbytes_4bit_shard(
     return quantized_size, quantized_offset
 
 
-def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
+def adjust_scalar_to_fused_array(
+    param_data: torch.Tensor,
+    loaded_weight: torch.Tensor,
+    shard_id: int | str,
+) -> tuple[torch.Tensor, torch.Tensor]:
     """For fused modules (QKV and MLP) we have an array of length
     N that holds 1 scale for each "logical" matrix. So the param
     is an array of length N. The loaded_weight corresponds to
@@ -117,12 +131,14 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
         assert loaded_weight.shape[0] == 1
         loaded_weight = loaded_weight[0]
 
-    return param[shard_id], loaded_weight
+    return param_data[shard_id], loaded_weight
 
 
 # TODO(Isotr0py): We might need a more flexible structure to handle
 # bitsandbytes shard offsets.
-def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
+def left_shift_bitsandbytes_4bit_shard(
+    bnb_weight_attrs: dict[str, Any],
+) -> tuple[dict[str, Any], dict[str, Any]]:
     """
     Separate the BitsAndBytes 4-bit shard.
 
@@ -681,12 +697,41 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             disable_tp=disable_tp,
         )
 
+    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
+        if loaded_shard_id is None:
+            return
+        if isinstance(loaded_shard_id, tuple):
+            for idx in loaded_shard_id:
+                if not (0 <= idx < len(self.output_sizes)):
+                    raise ValueError(
+                        f"Shard id index {idx} should be between 0 and "
+                        f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
+                    )
+            if len(loaded_shard_id) > 1 and any(
+                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
+            ):
+                raise ValueError(
+                    "Shard id with multiple indices should be consecutive. "
+                    f"Got shard id {loaded_shard_id}."
+                )
+            return
+        elif isinstance(loaded_shard_id, int):
+            if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
+                raise ValueError(
+                    f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
+                    f"Got shard id {loaded_shard_id}."
+                )
+            return
+        raise ValueError("This line should not be reached")
+
     def weight_loader(
         self,
         param: Parameter,
         loaded_weight: torch.Tensor,
         loaded_shard_id: tuple[int, ...] | int | None = None,
     ):
+        self.validate_shard_id(loaded_shard_id)
+        # FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
         if isinstance(loaded_shard_id, tuple):
             raise NotImplementedError(
                 "Shard id with multiple indices is not supported in weight_loader, "
@@ -874,6 +919,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         loaded_weight: torch.Tensor,
         loaded_shard_id: tuple[int, ...] | int | None = None,
     ):
+        self.validate_shard_id(loaded_shard_id)
         if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
             if isinstance(param, PerTensorScaleParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
@@ -1005,6 +1051,18 @@ class QKVParallelLinear(ColumnParallelLinear):
             disable_tp=disable_tp,
         )
 
+    def validate_shard_id(self, loaded_shard_id: str | None):
+        if loaded_shard_id is None:
+            return
+        if isinstance(loaded_shard_id, str):
+            if loaded_shard_id not in ["q", "k", "v"]:
+                raise ValueError(
+                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
+                    f"got shard id {loaded_shard_id}."
+                )
+            return
+        raise ValueError("This line should not be reached")
+
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
             "q": 0,
@@ -1073,6 +1131,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         loaded_weight: torch.Tensor,
         loaded_shard_id: str | None = None,
     ):
+        self.validate_shard_id(loaded_shard_id)
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
                 param.load_qkv_weight(
@@ -1112,6 +1171,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         loaded_weight: torch.Tensor,
         loaded_shard_id: str | None = None,
     ):
+        self.validate_shard_id(loaded_shard_id)
         # Special case for GGUF
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)