fix: Use an iterator so as not to store all the file loads in memory at once (#36149)

Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
This commit is contained in:
Shaun Kotek
2026-03-09 05:25:21 +02:00
committed by GitHub
parent dcf8862fd4
commit 90512b2e8b

View File

@@ -773,7 +773,9 @@ def multi_thread_safetensors_weights_iterator(
return result
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
# NOTE: use a generator here so that we do not hold all the loaded files
# in memory at the same time, which can cause OOM for large models.
futures = (executor.submit(_load_file, st_file) for st_file in hf_weights_files)
futures_iter = tqdm(
concurrent.futures.as_completed(futures),
total=len(hf_weights_files),
@@ -784,7 +786,9 @@ def multi_thread_safetensors_weights_iterator(
for future in futures_iter:
state_dict = future.result()
yield from state_dict.items()
del future
for key in list(state_dict):
yield key, state_dict.pop(key)
def runai_safetensors_weights_iterator(