[Minor] Sort safetensors files to ensure deterministic loading order (#33491)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import IO, Any
|
||||
import filelock
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load, load_file, safe_open, save_file
|
||||
@@ -143,6 +144,15 @@ def atomic_writer(
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
def _natural_sort_key(filepath: str) -> list:
|
||||
"""Natural sort key for filenames with numeric components, such as
|
||||
model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']"""
|
||||
return [
|
||||
int(s) if s.isdigit() else s
|
||||
for s in re.split(r"(\d+)", os.path.basename(filepath))
|
||||
]
|
||||
|
||||
|
||||
def maybe_download_from_modelscope(
|
||||
model: str,
|
||||
revision: str | None = None,
|
||||
@@ -682,9 +692,8 @@ def safetensors_weights_iterator(
|
||||
loading_desc += " (eager)"
|
||||
|
||||
leftover_state_dict: dict[str, torch.Tensor] = {}
|
||||
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
sorted(hf_weights_files, key=_natural_sort_key),
|
||||
desc=loading_desc,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
|
||||
Reference in New Issue
Block a user