[Bugfix][Quantization] Fix PerTensorScale loading with tuple shard_id in MergedColumnParallelLinear (#38517)
Signed-off-by: loukang <loukang@xiaohongshu.com>
This commit is contained in:
@@ -910,7 +910,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
- param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
||||
+ if isinstance(loaded_shard_id, tuple):
|
||||
+     for idx in loaded_shard_id:
|
||||
+         param.load_merged_column_weight(
|
||||
+             loaded_weight=loaded_weight, shard_id=idx
|
||||
+         )
|
||||
+ else:
|
||||
+     param.load_merged_column_weight(
|
||||
+         loaded_weight=loaded_weight, shard_id=0
|
||||
+     )
|
||||
return
|
||||
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
|
||||
param.load_merged_column_weight(loaded_weight=loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user