[Bugfix] Fix weight_loader v1 block scale (#31103)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
This commit is contained in:
@@ -80,6 +80,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
def adjust_block_scale_shard(
    weight_block_size, shard_size: int, shard_offset: int
) -> tuple[int, int]:
    """Convert a shard's size and offset into block-scale units.

    Both values are divided by the first entry of ``weight_block_size``,
    rounding up, so that a partially filled trailing block still counts
    as one block.

    Args:
        weight_block_size: Indexable block shape; only element ``0`` is
            read here. Must not be ``None``.
        shard_size: Shard length before the block adjustment.
        shard_offset: Shard start offset before the block adjustment.

    Returns:
        ``(shard_size, shard_offset)`` after the ceil-division. Note that
        the size comes first, although it is the second positional input.

    Raises:
        AssertionError: If ``weight_block_size`` is ``None``.
    """
    assert weight_block_size is not None, (
        "weight_block_size must be set to adjust a block scale shard"
    )
    block_n = weight_block_size[0]

    # Ceil-division: (x + n - 1) // n rounds up for non-negative x.
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset
|
||||
|
||||
|
||||
def adjust_bitsandbytes_4bit_shard(
|
||||
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
|
||||
) -> tuple[int, int]:
|
||||
@@ -763,8 +771,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
if output_dim is not None:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
weight_block_size = getattr(self, "weight_block_size", None)
|
||||
shard_size, shard_offset = adjust_block_scale_shard(
|
||||
weight_block_size, shard_size, shard_offset
|
||||
)
|
||||
|
||||
shard_offset //= self.tp_size
|
||||
shard_size //= self.tp_size
|
||||
|
||||
# Special case for quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
@@ -867,24 +885,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
assert self.quant_method is not None
|
||||
# Assume the weight block size has been set by quant method
|
||||
assert hasattr(self, "weight_block_size")
|
||||
weight_block_size = self.weight_block_size
|
||||
assert weight_block_size is not None
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (
|
||||
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
|
||||
) // self.tp_size
|
||||
shard_size = (
|
||||
(self.output_sizes[loaded_shard_id] + block_n - 1)
|
||||
// block_n
|
||||
// self.tp_size
|
||||
weight_block_size = getattr(self, "weight_block_size", None)
|
||||
shard_size, shard_offset = adjust_block_scale_shard(
|
||||
weight_block_size, shard_size, shard_offset
|
||||
)
|
||||
else:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||
|
||||
shard_offset //= self.tp_size
|
||||
shard_size //= self.tp_size
|
||||
|
||||
param.load_merged_column_weight(
|
||||
loaded_weight=loaded_weight,
|
||||
@@ -1066,16 +1077,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
|
||||
shard_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||
|
||||
# Note(simon): This is needed for Qwen3's fp8 quantization.
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
assert self.quant_method is not None
|
||||
# Assume the weight block size has been set by quant method
|
||||
assert hasattr(self, "weight_block_size")
|
||||
weight_block_size = self.weight_block_size
|
||||
assert weight_block_size is not None
|
||||
block_n, _ = weight_block_size[0], weight_block_size[1]
|
||||
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||
shard_size = (shard_size + block_n - 1) // block_n
|
||||
weight_block_size = getattr(self, "weight_block_size", None)
|
||||
shard_size, shard_offset = adjust_block_scale_shard(
|
||||
weight_block_size, shard_size, shard_offset
|
||||
)
|
||||
|
||||
param.load_qkv_weight(
|
||||
loaded_weight=loaded_weight,
|
||||
@@ -1208,6 +1214,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
elif loaded_shard_id == "v":
|
||||
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.v_head_size
|
||||
|
||||
if isinstance(param, BlockQuantScaleParameter):
|
||||
weight_block_size = getattr(self, "weight_block_size", None)
|
||||
shard_size, shard_offset = adjust_block_scale_shard(
|
||||
weight_block_size, shard_size, shard_offset
|
||||
)
|
||||
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
|
||||
Reference in New Issue
Block a user