[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (#5921)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
@@ -383,8 +383,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
if output_dim is None:
|
||||
# If fp8 + scale, need to send to each shard.
|
||||
if fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
@@ -567,8 +572,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
if output_dim is None:
|
||||
# If fp8 + scale, need to send to each shard.
|
||||
if fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user