[Bugfix] Fix weight_loader v1 block scale (#31103)

Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
This commit is contained in:
Kyuyeun Kim
2026-01-01 21:14:10 -08:00
committed by GitHub
parent 825c2dc133
commit cc410e8644

View File

@@ -80,6 +80,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
assert weight_block_size is not None
block_n = weight_block_size[0]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
return shard_size, shard_offset
def adjust_bitsandbytes_4bit_shard(
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
@@ -763,8 +771,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
shard_offset //= self.tp_size
shard_size //= self.tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
@@ -867,24 +885,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
if isinstance(param, BlockQuantScaleParameter):
assert self.quant_method is not None
# Assume the weight block size has been set by quant method
assert hasattr(self, "weight_block_size")
weight_block_size = self.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
) // self.tp_size
shard_size = (
(self.output_sizes[loaded_shard_id] + block_n - 1)
// block_n
// self.tp_size
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
shard_offset //= self.tp_size
shard_size //= self.tp_size
param.load_merged_column_weight(
loaded_weight=loaded_weight,
@@ -1066,16 +1077,11 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
# Note(simon): This is needed for Qwen3's fp8 quantization.
if isinstance(param, BlockQuantScaleParameter):
assert self.quant_method is not None
# Assume the weight block size has been set by quant method
assert hasattr(self, "weight_block_size")
weight_block_size = self.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
param.load_qkv_weight(
loaded_weight=loaded_weight,
@@ -1208,6 +1214,13 @@ class QKVParallelLinear(ColumnParallelLinear):
elif loaded_shard_id == "v":
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.v_head_size
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.