diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 24b2f61b8..e00a17a15 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -773,7 +773,9 @@ def multi_thread_safetensors_weights_iterator( return result with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files] + # NOTE: Use a generator here so we do not keep every loaded file in memory + # at the same time, which can cause OOM for large models. + futures = (executor.submit(_load_file, st_file) for st_file in hf_weights_files) futures_iter = tqdm( concurrent.futures.as_completed(futures), total=len(hf_weights_files), @@ -784,7 +786,9 @@ def multi_thread_safetensors_weights_iterator( for future in futures_iter: state_dict = future.result() - yield from state_dict.items() + del future + for key in list(state_dict): + yield key, state_dict.pop(key) def runai_safetensors_weights_iterator(