diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 402f0bf69..ebdc05449 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -80,6 +80,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset): + assert weight_block_size is not None + block_n = weight_block_size[0] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + return shard_size, shard_offset + + def adjust_bitsandbytes_4bit_shard( param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str ) -> tuple[int, int]: @@ -763,8 +771,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size - shard_size = self.output_sizes[loaded_shard_id] // self.tp_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] + + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = getattr(self, "weight_block_size", None) + shard_size, shard_offset = adjust_block_scale_shard( + weight_block_size, shard_size, shard_offset + ) + + shard_offset //= self.tp_size + shard_size //= self.tp_size + # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. @@ -867,24 +885,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] + if isinstance(param, BlockQuantScaleParameter): - assert self.quant_method is not None - # Assume the weight block size has been set by quant method - assert hasattr(self, "weight_block_size") - weight_block_size = self.weight_block_size - assert weight_block_size is not None - block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n - ) // self.tp_size - shard_size = ( - (self.output_sizes[loaded_shard_id] + block_n - 1) - // block_n - // self.tp_size + weight_block_size = getattr(self, "weight_block_size", None) + shard_size, shard_offset = adjust_block_scale_shard( + weight_block_size, shard_size, shard_offset ) - else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size - shard_size = self.output_sizes[loaded_shard_id] // self.tp_size + + shard_offset //= self.tp_size + shard_size //= self.tp_size param.load_merged_column_weight( loaded_weight=loaded_weight, @@ -1066,16 +1077,11 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id) - # Note(simon): This is needed for Qwen3's fp8 quantization. if isinstance(param, BlockQuantScaleParameter): - assert self.quant_method is not None - # Assume the weight block size has been set by quant method - assert hasattr(self, "weight_block_size") - weight_block_size = self.weight_block_size - assert weight_block_size is not None - block_n, _ = weight_block_size[0], weight_block_size[1] - shard_offset = (shard_offset + block_n - 1) // block_n - shard_size = (shard_size + block_n - 1) // block_n + weight_block_size = getattr(self, "weight_block_size", None) + shard_size, shard_offset = adjust_block_scale_shard( + weight_block_size, shard_size, shard_offset + ) param.load_qkv_weight( loaded_weight=loaded_weight, @@ -1208,6 +1214,13 @@ class QKVParallelLinear(ColumnParallelLinear): elif loaded_shard_id == "v": shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.v_head_size + + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = getattr(self, "weight_block_size", None) + shard_size, shard_offset = adjust_block_scale_shard( + weight_block_size, shard_size, shard_offset + ) + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing.