[Bugfix] Fix Qwen3.5-FP8 Weight Loading Error on TPU (#37348)

Signed-off-by: Jacob Platin <jacobplatin@google.com>
This commit is contained in:
Jacob Platin
2026-03-25 17:46:01 -07:00
committed by GitHub
parent 3c3c084240
commit d7d51a7ee5

View File

@@ -768,6 +768,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
# Add check to adjust the size/offset for FP8 block scales
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
@@ -1218,6 +1225,13 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
# Add check to adjust the size/offset for FP8 block scales
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor