[Bugfix][Quantization] Fix PerTensorScale loading with tuple shard_id in MergedColumnParallelLinear (#38517)

Signed-off-by: loukang <loukang@xiaohongshu.com>
This commit is contained in:
kkyyxhll
2026-04-07 23:16:26 +08:00
committed by GitHub
parent 729eb59f60
commit 98e1a43af7

View File

@@ -910,7 +910,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         self.validate_shard_id(loaded_shard_id)
         if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
+                if isinstance(loaded_shard_id, tuple):
+                    for idx in loaded_shard_id:
+                        param.load_merged_column_weight(
+                            loaded_weight=loaded_weight, shard_id=idx
+                        )
+                else:
+                    param.load_merged_column_weight(
+                        loaded_weight=loaded_weight, shard_id=0
+                    )
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)