fix: Use iterator so as not to store all the file loads in memory at once (#36149)
Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
This commit is contained in:
@@ -773,7 +773,9 @@ def multi_thread_safetensors_weights_iterator(
         return result

     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+        # Note to use generator here so we do not store all the loaded files in memory
+        # at the same time, which can cause OOM for large models.
+        futures = (executor.submit(_load_file, st_file) for st_file in hf_weights_files)
         futures_iter = tqdm(
             concurrent.futures.as_completed(futures),
             total=len(hf_weights_files),
@@ -784,7 +786,9 @@ def multi_thread_safetensors_weights_iterator(

     for future in futures_iter:
         state_dict = future.result()
-        yield from state_dict.items()
+        del future
+        for key in list(state_dict):
+            yield key, state_dict.pop(key)


 def runai_safetensors_weights_iterator(
Reference in New Issue
Block a user