[Bugfix] Fix Qwen3.5-FP8 Weight Loading Error on TPU (#37348)
Signed-off-by: Jacob Platin <jacobplatin@google.com>
This commit is contained in:
@@ -768,6 +768,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
+            # Add check to adjust the size/offset for FP8 block scales
+            if isinstance(param, BlockQuantScaleParameter):
+                weight_block_size = getattr(self, "weight_block_size", None)
+                shard_size, shard_offset = adjust_block_scale_shard(
+                    weight_block_size, shard_size, shard_offset
+                )
+
             if packed_dim == output_dim:
                 shard_size = shard_size // param.packed_factor
                 shard_offset = shard_offset // param.packed_factor
@@ -1218,6 +1225,13 @@ class QKVParallelLinear(ColumnParallelLinear):
             # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
+            # Add check to adjust the size/offset for FP8 block scales
+            if isinstance(param, BlockQuantScaleParameter):
+                weight_block_size = getattr(self, "weight_block_size", None)
+                shard_size, shard_offset = adjust_block_scale_shard(
+                    weight_block_size, shard_size, shard_offset
+                )
+
             if packed_dim == output_dim:
                 shard_size = shard_size // param.packed_factor
                 shard_offset = shard_offset // param.packed_factor
|
||||
Reference in New Issue
Block a user