[torchao] fix safetensors for sharding (#28169)
Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
@@ -279,7 +279,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
if (
|
||||
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||
and quant_config.is_checkpoint_torchao_serialized
|
||||
and torchao_version_at_least("0.14.0")
|
||||
and torchao_version_at_least("0.15.0")
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
|
||||
|
||||
@@ -595,6 +595,9 @@ def safetensors_weights_iterator(
|
||||
if safetensors_load_strategy == "eager":
|
||||
loading_desc += " (eager)"
|
||||
|
||||
state_dict = {}
|
||||
leftover_state_dict: dict[str, torch.Tensor] = {}
|
||||
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc=loading_desc,
|
||||
@@ -606,9 +609,11 @@ def safetensors_weights_iterator(
|
||||
state_dict = load(f.read())
|
||||
yield from state_dict.items()
|
||||
elif safetensors_load_strategy == "torchao":
|
||||
if not torchao_version_at_least("0.14.0"):
|
||||
# we can't load flattened torchao tensor subclasses directly into the model
|
||||
# instead we reconstruct the subclasses here before returning
|
||||
if not torchao_version_at_least("0.15.0"):
|
||||
raise ValueError(
|
||||
"Please use torchao version >= 0.14.0 \
|
||||
"Please use torchao version >= 0.15.0 \
|
||||
to load torchao safetensors checkpoint"
|
||||
)
|
||||
from torchao.prototype.safetensors.safetensors_support import (
|
||||
@@ -616,12 +621,20 @@ def safetensors_weights_iterator(
|
||||
)
|
||||
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
state_dict = {}
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
state_dict[name] = f.get_tensor(name)
|
||||
|
||||
# update with leftover tensor data from previous iteration, if any
|
||||
state_dict.update(leftover_state_dict)
|
||||
metadata = f.metadata()
|
||||
updated_state_dict = unflatten_tensor_state_dict(state_dict, metadata)
|
||||
yield from updated_state_dict.items()
|
||||
# due to sharded checkpoints, we are not guaranteed that we have all
|
||||
# tensor subclass data on one file
|
||||
# state_dict has the leftover data from this step and we wait for
|
||||
# missing information to be provided in a future iteration
|
||||
unflattened_state_dict, leftover_state_dict = (
|
||||
unflatten_tensor_state_dict(state_dict, metadata)
|
||||
)
|
||||
yield from unflattened_state_dict.items()
|
||||
else:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
|
||||
Reference in New Issue
Block a user